
Fix JAX error with FourierCurrentPotentialField in Flux objectives #1002

Open: wants to merge 30 commits into base: master

@dpanici (Collaborator) commented Apr 19, 2024

On master, an optimization combining a FourierCurrentPotentialField with the QuadraticFlux objective fails. QuadraticFlux.compute calls field.compute_magnetic_field, and if the field needs transforms to evaluate and none are provided, they are created on the fly inside the objective's compute, which results in a JAX error.

This PR fixes that by adding jitable=True to the .compute call inside .compute_magnetic_field.
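To illustrate why the flag matters, here is a minimal sketch under loudly stated assumptions: none of these names are DESC's, and "tracing" is modeled with a plain boolean rather than real JAX tracing. Following the description above, it assumes that non-jitable transform setup fails when it happens inside a traced objective, while a jitable transform (here just a precomputed basis matrix) can safely be built on the fly.

```python
import math


class TraceError(RuntimeError):
    """Hypothetical stand-in for the error raised on non-traceable setup."""


def build_transform(n_modes, n_nodes, jitable, tracing):
    """Hypothetical transform builder (not DESC's Transform class)."""
    if tracing and not jitable:
        # Model: FFT-style setup is not allowed inside a traced function.
        raise TraceError("non-jitable transform built inside a traced function")
    # Jitable path: a plain basis matrix, B[i][m] = cos(m * theta_i).
    thetas = [2 * math.pi * i / n_nodes for i in range(n_nodes)]
    return [[math.cos(m * t) for m in range(n_modes)] for t in thetas]


def compute(coeffs, n_nodes, transforms=None, jitable=False, tracing=False):
    """Evaluate sum_m coeffs[m] * cos(m * theta) at n_nodes equispaced angles."""
    if transforms is None:
        # Built on the fly when not provided: this is where the error surfaces.
        transforms = build_transform(len(coeffs), n_nodes, jitable, tracing)
    return [sum(b * c for b, c in zip(row, coeffs)) for row in transforms]
```

Here `compute([1.0, 0.5], 4, jitable=True, tracing=True)` returns four values, while the same call with `jitable=False` raises `TraceError`, mirroring the failure on master.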

It also adds a transforms argument to the MagneticField.compute_magnetic_field methods, in preparation for when #1079 is resolved and objectives can easily pre-compute and pass in transform objects for all magnetic fields.
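The new argument follows a familiar pattern: build expensive helpers once (for instance in an objective's build) and pass them into every later evaluation, falling back to building them on demand only when none are supplied. The sketch below uses hypothetical names (`ToyField`, `make_transforms`) and a build counter purely to show the reuse; it is not DESC's MagneticField API.

```python
class ToyField:
    """Hypothetical field whose evaluation needs precomputed 'transforms'."""

    def __init__(self, scale):
        self.scale = scale
        self.n_builds = 0  # counts how often transforms were constructed

    def make_transforms(self, source_grid):
        self.n_builds += 1
        # Stand-in for real transform construction: one weight per grid node.
        return {"weights": [1.0 / len(source_grid)] * len(source_grid)}

    def compute_magnetic_field(self, coords, source_grid, transforms=None):
        if transforms is None:
            # Fallback: build on the fly (convenient, but repeated every call).
            transforms = self.make_transforms(source_grid)
        w = transforms["weights"]
        # Toy "field": scale * (weighted mean of source nodes) * coordinate.
        total = sum(wi * si for wi, si in zip(w, source_grid))
        return [self.scale * total * x for x in coords]


field = ToyField(scale=2.0)
grid = [1.0, 2.0, 3.0]
pre = field.make_transforms(grid)  # done once, e.g. in an objective's build
for _ in range(5):
    field.compute_magnetic_field([1.0, 2.0], grid, transforms=pre)
assert field.n_builds == 1  # five evaluations, one build
```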

This PR kicks the can down the road on making get_transforms work with the MagneticField.compute_magnetic_field methods (which would allow us to pre-compute the transforms used for magnetic field computation, #1079) and instead opts for the simple fix.

github-actions bot (Contributor) commented Apr 19, 2024
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| --- | --- | --- | --- | --- |
| test_build_transform_fft_lowres | +7.06 +/- 8.20 | +3.58e-02 +/- 4.16e-02 | 5.43e-01 +/- 4.1e-02 | 5.07e-01 +/- 6.4e-03 |
| test_build_transform_fft_midres | +6.67 +/- 3.29 | +3.96e-02 +/- 1.95e-02 | 6.33e-01 +/- 1.7e-02 | 5.93e-01 +/- 9.9e-03 |
| test_build_transform_fft_highres | +3.78 +/- 3.51 | +3.73e-02 +/- 3.46e-02 | 1.02e+00 +/- 1.5e-02 | 9.87e-01 +/- 3.1e-02 |
| test_equilibrium_init_lowres | +1.75 +/- 2.66 | +6.43e-02 +/- 9.76e-02 | 3.73e+00 +/- 9.7e-02 | 3.67e+00 +/- 9.8e-03 |
| test_equilibrium_init_medres | +1.82 +/- 2.17 | +7.53e-02 +/- 8.98e-02 | 4.22e+00 +/- 8.9e-02 | 4.14e+00 +/- 1.0e-02 |
| test_equilibrium_init_highres | +1.21 +/- 1.88 | +6.71e-02 +/- 1.04e-01 | 5.60e+00 +/- 8.5e-02 | 5.54e+00 +/- 6.0e-02 |
| test_objective_compile_dshape_current | +0.64 +/- 5.96 | +2.43e-02 +/- 2.26e-01 | 3.81e+00 +/- 2.2e-01 | 3.79e+00 +/- 1.9e-02 |
| test_objective_compile_atf | +1.09 +/- 3.08 | +8.93e-02 +/- 2.54e-01 | 8.32e+00 +/- 1.5e-01 | 8.23e+00 +/- 2.0e-01 |
| test_objective_compute_dshape_current | -1.81 +/- 4.61 | -2.31e-05 +/- 5.87e-05 | 1.25e-03 +/- 2.9e-05 | 1.27e-03 +/- 5.1e-05 |
| test_objective_compute_atf | +0.42 +/- 6.38 | +1.76e-05 +/- 2.70e-04 | 4.25e-03 +/- 2.3e-04 | 4.23e-03 +/- 1.5e-04 |
| test_objective_jac_dshape_current | -1.67 +/- 11.51 | -6.06e-04 +/- 4.18e-03 | 3.57e-02 +/- 2.3e-03 | 3.64e-02 +/- 3.5e-03 |
| test_objective_jac_atf | +3.58 +/- 2.49 | +6.67e-02 +/- 4.63e-02 | 1.93e+00 +/- 3.4e-02 | 1.86e+00 +/- 3.2e-02 |
| test_perturb_1 | +0.51 +/- 0.68 | +6.71e-02 +/- 8.89e-02 | 1.31e+01 +/- 7.4e-02 | 1.30e+01 +/- 4.9e-02 |
| test_perturb_2 | +0.74 +/- 1.23 | +1.33e-01 +/- 2.22e-01 | 1.81e+01 +/- 1.3e-01 | 1.80e+01 +/- 1.8e-01 |
| test_proximal_jac_atf | -0.61 +/- 1.44 | -4.48e-02 +/- 1.05e-01 | 7.28e+00 +/- 6.0e-02 | 7.32e+00 +/- 8.7e-02 |
| test_proximal_freeb_compute | -1.67 +/- 0.74 | -3.00e-03 +/- 1.32e-03 | 1.76e-01 +/- 1.0e-03 | 1.79e-01 +/- 8.2e-04 |
| test_proximal_freeb_jac | -0.12 +/- 1.09 | -8.66e-03 +/- 7.98e-02 | 7.31e+00 +/- 5.9e-02 | 7.32e+00 +/- 5.4e-02 |
| test_solve_fixed_iter | -0.38 +/- 9.08 | -5.59e-02 +/- 1.34e+00 | 1.47e+01 +/- 9.9e-01 | 1.48e+01 +/- 9.0e-01 |
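Reading the columns: the numbers appear consistent with dt = t_new - t_old, the two timing uncertainties added in quadrature, and dt(%) = 100 * dt / t_old. This is a reconstruction from the table itself, not documentation of the benchmark bot, and small mismatches come from the rounded t values shown. Checked against the first row:

```python
import math


def relative_change(t_new, s_new, t_old, s_old):
    """Return (dt, s_dt, dt_pct, s_pct), assuming independent uncertainties."""
    dt = t_new - t_old
    s_dt = math.sqrt(s_new**2 + s_old**2)  # add uncertainties in quadrature
    dt_pct = 100.0 * dt / t_old
    s_pct = 100.0 * s_dt / t_old
    return dt, s_dt, dt_pct, s_pct


# test_build_transform_fft_lowres, using the rounded values from the table
dt, s_dt, dt_pct, s_pct = relative_change(0.543, 0.041, 0.507, 0.0064)
print(f"{dt_pct:+.2f} +/- {s_pct:.2f} %")
```

This gives roughly +7.10 +/- 8.18 %, matching the table's +7.06 +/- 8.20 up to rounding of the inputs. With uncertainties this size, most rows are within noise of zero.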

codecov bot commented Apr 19, 2024

Codecov Report

Attention: Patch coverage is 78.94737% with 4 lines in your changes missing coverage. Please review.

Project coverage is 94.97%. Comparing base (0a6b995) to head (63e1cd9).

Additional details and impacted files
```diff
@@            Coverage Diff             @@
##           master    #1002      +/-   ##
==========================================
- Coverage   94.98%   94.97%   -0.01%
==========================================
  Files          87       87
  Lines       21749    21762      +13
==========================================
+ Hits        20658    20669      +11
- Misses       1091     1093       +2
```

| Files | Coverage Δ |
| --- | --- |
| desc/geometry/core.py | 95.91% <100.00%> (ø) |
| desc/objectives/_coils.py | 99.12% <ø> (ø) |
| desc/coils.py | 96.88% <85.71%> (+0.02%) ⬆️ |
| desc/magnetic_fields/_current_potential.py | 98.83% <75.00%> (-0.58%) ⬇️ |
| desc/magnetic_fields/_core.py | 96.40% <71.42%> (-0.27%) ⬇️ |

... and 1 file with indirect coverage changes
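The headline percentages can be reproduced from the diff counts. Two assumptions here: coverage is hits / lines, and the project figure is rounded down to two decimals (this matches Codecov's default `round: down`, `precision: 2` settings). The 78.94737% patch figure is consistent with 15 of 19 changed lines covered and 4 missed.

```python
import math


def coverage_pct(hits, lines, precision=2):
    """Coverage in percent, rounded *down* to `precision` decimals."""
    factor = 10**precision
    return math.floor(100 * hits / lines * factor) / factor


base = coverage_pct(20658, 21749)  # master
head = coverage_pct(20669, 21762)  # this PR
print(base, head, round(head - base, 2))
print(round(100 * 15 / 19, 5))  # patch: 15 of 19 changed lines hit
```

Rounding down rather than to nearest is why the head reads 94.97% even though 20669 / 21762 is about 94.9775%.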

@f0uriest (Member) left a comment

I think the reason coils worked while the current potential doesn't is that by default the coil classes create jitable transforms in their compute method; since they're only 1D, it's not really worth trying to do the FFT.

I'd recommend adding the new classes to the existing get_transforms logic. I think most of it should "just work" assuming attributes are named correctly (it might need some logic somewhere for tree-like coilsets etc., but it's doable).

@@ -753,13 +753,28 @@ def build(self, use_jit=True, verbose=1):

```python
Bplasma = compute_B_plasma(
    eq, eval_grid, self._source_grid, normal_only=True
)
field = self._field
if hasattr(field, "Phi_mn"):
```
Member

Would it be better to get this all to work with the existing desc.compute.utils.get_transforms helper? That would work for coils etc. as well.

@dpanici marked this pull request as draft April 25, 2024 15:12
@dpanici (Collaborator, Author) commented Apr 30, 2024

> I think the reason coils worked while the current potential doesn't is that by default the coil classes create jitable transforms in their compute method; since they're only 1D, it's not really worth trying to do the FFT.
>
> I'd recommend adding the new classes to the existing get_transforms logic. I think most of it should "just work" assuming attributes are named correctly (it might need some logic somewhere for tree-like coilsets etc., but it's doable).

The problem is less the classes and more that not every magnetic field class needs the same keys to calculate the magnetic field. Coils need "I" and "x", the current potential fields need "K", and toroidal fields don't need any keys, as they just have their own compute_magnetic_field function.
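To make the mismatch concrete: a generic helper would first have to pick the compute keys from the field type. A minimal sketch with hypothetical stand-in classes (not DESC's) and the keys named above; the current potential entry also requests "x", as the `field.compute(["K", "x"], ...)` call in `_compute_magnetic_field_from_CurrentPotentialField` does.

```python
class Coil:
    """Stand-in for a coil (1D curve carrying current)."""


class CurrentPotentialField:
    """Stand-in for a surface current potential field."""


class ToroidalField:
    """Stand-in for an analytic field with its own closed-form evaluation."""


# Keys each field type needs before its magnetic field can be evaluated.
REQUIRED_KEYS = {
    Coil: ["I", "x"],
    CurrentPotentialField: ["K", "x"],
    ToroidalField: [],  # evaluated directly, so no transforms are needed
}


def keys_for(field):
    """Look up the keys for a field, walking the MRO so subclasses inherit."""
    for cls in type(field).__mro__:
        if cls in REQUIRED_KEYS:
            return REQUIRED_KEYS[cls]
    raise TypeError(f"no key mapping for {type(field).__name__}")
```

Walking the MRO lets a specific coil parameterization inherit its parent's keys, but every new field family still needs an entry, which is the maintenance cost under discussion.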

@dpanici (Collaborator, Author) commented May 1, 2024

Add a case for coils and 1D transforms, by calling get_transforms in build and always asking for "x" (if the object is a curve class or CoilSet), and maybe also "K" if it is a current potential field.

Also add jitable=True in the current potential field's compute call, since a current potential field inside a SumMagneticField might not work correctly otherwise.

tests/test_optimizer.py (outdated, resolved)
tests/test_optimizer.py (outdated, resolved)
@dpanici (Collaborator, Author) commented Jun 10, 2024

@f0uriest I don't exactly remember the issue with the current method, or how putting the logic in get_transforms would improve it.

We would still need to call get_transforms with the correct grid based on what the underlying object is (which would still require a check in the build method to see if it is a coil or a field), and to ask for the correct keys ("x", and maybe "K" as well if it is a current potential field).

I think I also found in this PR a fix for something causing a JAX error in a few places when current potential fields are being optimized, so if possible I'd like to get this one in, unless there is a better case for putting more logic into get_transforms now. I think that should wait until we make "B" a data index quantity for MagneticField objects; then we can go through that route to get all the proper transforms etc. based on the specific MagneticField parameterization.
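One way to read the longer-term proposal: if "B" were registered per MagneticField parameterization in a data index with its dependencies declared there, a generic helper could find the needed transforms by walking the dependency graph instead of special-casing each class in build. A toy sketch of that idea; the registry layout, transform names, and dependency lists are hypothetical and are not DESC's data_index format.

```python
# Hypothetical registry: (parameterization, quantity) -> declared dependencies.
DATA_INDEX = {
    ("CurrentPotentialField", "B"): {"data": ["K", "x"], "transforms": []},
    ("CurrentPotentialField", "K"): {"data": [], "transforms": ["Phi"]},
    ("CurrentPotentialField", "x"): {"data": [], "transforms": ["R", "Z"]},
    ("Coil", "B"): {"data": ["I", "x"], "transforms": []},
    ("Coil", "I"): {"data": [], "transforms": []},
    ("Coil", "x"): {"data": [], "transforms": ["X", "Y", "Z"]},
    ("ToroidalField", "B"): {"data": [], "transforms": []},
}


def needed_transforms(parameterization, key, seen=None):
    """Collect transforms needed for `key` by walking declared dependencies."""
    seen = set() if seen is None else seen
    if key in seen:
        return set()  # already visited via another dependency path
    seen.add(key)
    entry = DATA_INDEX[(parameterization, key)]
    out = set(entry["transforms"])
    for dep in entry["data"]:
        out |= needed_transforms(parameterization, dep, seen)
    return out
```

With this layout, supporting a new parameterization means adding registry entries rather than new branches in each objective's build.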

@dpanici marked this pull request as ready for review June 10, 2024 17:48
desc/coils.py (resolved)
desc/objectives/_coils.py (outdated, resolved)
desc/objectives/_coils.py (outdated, resolved)


```python
@pytest.mark.unit
def test_tor_flux_with_surface_current_field():
```
Collaborator

Could these tests be one test with an inner test() function, or is it better to keep them separate?
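For reference, the single-test shape being suggested might look like the sketch below: one shared body factored into an inner helper that runs once per objective. `QuadraticFluxStub` and `ToroidalFluxStub` are hypothetical placeholders for the objectives under test, and the `pytest.mark.unit` marker is omitted so the snippet runs on its own.

```python
class QuadraticFluxStub:
    """Placeholder objective: pretends to compute a finite scalar."""

    def __init__(self, field):
        self.field = field

    def compute(self):
        return 0.0


class ToroidalFluxStub(QuadraticFluxStub):
    """Second placeholder objective sharing the same interface."""


def test_flux_objectives_with_surface_current_field():
    field = object()  # stand-in for a FourierCurrentPotentialField

    def test(objective_cls):
        # Shared body: build the objective around the field, check the output.
        obj = objective_cls(field)
        value = obj.compute()
        assert value == value, f"{objective_cls.__name__} returned NaN"

    for cls in (ToroidalFluxStub, QuadraticFluxStub):
        test(cls)
```

The trade-off is that a failure is reported under one test name for both objectives, whereas separate tests (or `pytest.mark.parametrize`) report each objective on its own.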

desc/coils.py (resolved)
@@ -644,7 +647,23 @@ def _compute_magnetic_field_from_CurrentPotentialField(

```python
# compute surface current, and store grid quantities
# needed for integration in class
# TODO: does this have to be xyz, or can it be computed in rpz as well?
data = field.compute(["K", "x"], grid=source_grid, basis="xyz", params=params)
if not params and not transforms:
```
Member

why this conditional?

Collaborator Author

To avoid calling field.compute inside the compute function of an objective, since field.compute contains some logic that jit does not like.

@dpanici (Collaborator, Author) commented Jun 19, 2024

Equinox for runtime error checking in jitted functions

@dpanici (Collaborator, Author) commented Jun 19, 2024

@f0uriest check this again for placing the code in a util function

@dpanici (Collaborator, Author) commented Jul 2, 2024

@f0uriest I think with this PR I will remove the changes I made in the objective build methods, keep the new transforms kwargs I added to all the fields, and make that the only change here (allowing transforms to be passed into MagneticField compute methods), along with the simple fix of putting jitable=True in the right spots. Then we can have a separate PR once we decide how to properly handle transforms for B with get_transforms.

I'd like to take this course of action because, embarrassingly, I had already merged this into #579, so it is hard to disentangle the two...

@dpanici requested review from f0uriest, ddudt and kianorr and removed request for f0uriest, ddudt and YigitElma July 2, 2024 14:36
@dpanici changed the title from "Add transform pre-computation for FourierCurrentPotentialField in Flux objectives" to "Fix JAX error with FourierCurrentPotentialField in Flux objectives" Jul 2, 2024
@dpanici added the EZ-review label (This PR takes less than 30 mins to review) Jul 2, 2024
rahulgaur104 previously approved these changes Jul 2, 2024
@rahulgaur104 (Collaborator)

Merge after increasing the coverage.

kianorr previously approved these changes Jul 2, 2024
@dpanici dismissed stale reviews from kianorr and rahulgaur104 via 07d110b July 3, 2024 00:02
tests/test_optimizer.py (resolved)


```python
@pytest.mark.unit
def test_quad_flux_with_surface_current_field():
```
Member

Is the plan to remove/combine this test with #1025?

Collaborator Author

Not sure, as that one will add coil tests while this one uses surface current fields. If one test on that branch involves a sum magnetic field with coils and a surface current field, then we can remove this one in favor of that one.

@dpanici requested a review from daniel-dudt July 4, 2024 15:33
@@ -1272,7 +1272,10 @@ def compute(self, field_params, constants=None):

```python
# B_ext is not pre-computed because field is not fixed
B_ext = constants["field"].compute_magnetic_field(
    x, source_grid=constants["field_grid"], basis="rpz", params=field_params
```
Collaborator

No real change to this file, right? You just like to make the code longer to bother me?

kianorr added a commit that referenced this pull request Jul 4, 2024
- doesn't pass due to jax bug, maybe PR #1002 will fix it?
@kianorr mentioned this pull request Jul 4, 2024
Labels: EZ-review (This PR takes less than 30 mins to review)

5 participants