2121from matplotlib import pyplot as plt
2222from orthax .chebyshev import chebroots , chebvander
2323
24- from desc .backend import dct , flatnonzero , fori_loop , idct , ifft , jnp
24+ from desc .backend import dct , flatnonzero , fori_loop , idct , ifft , jax , jnp
2525from desc .integrals ._interp_utils import (
2626 _eps ,
2727 _filter_distinct ,
2828 _subtract_first ,
29+ cubic_val ,
2930 nufft1d2r ,
3031 polyroot_vec ,
31- polyval_vec ,
3232)
3333from desc .integrals .quad_utils import bijection_from_disc , bijection_to_disc
3434from desc .io import IOAble
@@ -97,7 +97,7 @@ def _in_epigraph_and(is_intersect, df_dy, /):
9797 return is_intersect .at [idx [0 ]].set (edge_case )
9898
9999
100- def bounce_points (pitch_inv , knots , B , dB_dz , num_well = None ):
100+ def bounce_points (pitch_inv , knots , B , num_well = None ):
101101 """Compute the bounce points given 1D spline of B and pitch λ.
102102
103103 Parameters
@@ -109,15 +109,10 @@ def bounce_points(pitch_inv, knots, B, dB_dz, num_well=None):
109109 Shape (N, ).
110110 ζ coordinates of spline knots. Must be strictly increasing.
111111 B : jnp.ndarray
112- Shape (..., N - 1, B.shape[-1] ).
112+ Shape (..., N - 1, 4 ).
113113 Polynomial coefficients of the spline of B in local power basis.
114114 Last axis enumerates the coefficients of power series. Second to
115115 last axis enumerates the polynomials that compose a particular spline.
116- dB_dz : jnp.ndarray
117- Shape (..., N - 1, B.shape[-1] - 1).
118- Polynomial coefficients of the spline of (∂B/∂ζ)|(ρ,α) in local power basis.
119- Last axis enumerates the coefficients of power series. Second to
120- last axis enumerates the polynomials that compose a particular spline.
121116 num_well : int or None
122117 Specify to return the first ``num_well`` pairs of bounce points for each
123118 pitch and field line. Choosing ``-1`` will detect all wells, but due
@@ -145,8 +140,9 @@ def bounce_points(pitch_inv, knots, B, dB_dz, num_well=None):
145140 line and pitch, is padded with zero.
146141
147142 """
143+ B = B [..., None , :, :]
148144 intersect = polyroot_vec (
149- c = B [..., None , :, :] ,
145+ c = B ,
150146 k = jnp .atleast_1d (pitch_inv )[..., None ],
151147 a_min = jnp .array ([0.0 ]),
152148 a_max = jnp .diff (knots ),
@@ -156,9 +152,7 @@ def bounce_points(pitch_inv, knots, B, dB_dz, num_well=None):
156152 )
157153 assert intersect .shape [- 2 :] == (knots .size - 1 , B .shape [- 1 ] - 1 )
158154
159- dB_dz = flatten_mat (
160- jnp .sign (polyval_vec (x = intersect , c = dB_dz [..., None , :, None , :]))
161- )
155+ dB_dz = flatten_mat (jnp .sign (cubic_val (x = intersect , c = B [..., None , :], der = True )))
162156 # Only consider intersect if it is within knots that bound that polynomial.
163157 mask = flatten_mat (intersect >= 0 )
164158 z1 = (dB_dz <= 0 ) & mask
@@ -532,7 +526,7 @@ def _plot_intersect(
532526 )
533527
534528
535- def get_extrema (knots , g , dg_dz , sentinel = jnp .nan ):
529+ def get_extrema (knots , g , sentinel = jnp .nan ):
536530 """Return extrema (z*, g(z*)).
537531
538532 Parameters
@@ -541,30 +535,28 @@ def get_extrema(knots, g, dg_dz, sentinel=jnp.nan):
541535 Shape (N, ).
542536 ζ coordinates of spline knots. Must be strictly increasing.
543537 g : jnp.ndarray
544- Shape (..., N - 1, g.shape[-1] ).
538+ Shape (..., N - 1, 4 ).
545539 Polynomial coefficients of the spline of g in local power basis.
546540 Last axis enumerates the coefficients of power series. Second to
547541 last axis enumerates the polynomials that compose a particular spline.
548- dg_dz : jnp.ndarray
549- Shape (..., N - 1, g.shape[-1] - 1).
550- Polynomial coefficients of the spline of ∂g/∂z in local power basis.
551- Last axis enumerates the coefficients of power series. Second to
552- last axis enumerates the polynomials that compose a particular spline.
553542 sentinel : float
554543 Value with which to pad array to return fixed shape.
555544
556545 Returns
557546 -------
558547 ext, g_ext : jnp.ndarray
559- Shape (..., (N - 1) * (g.shape[-1] - 2) ).
548+ Shape (..., (N - 1) * 2 ).
560549 First array enumerates z*. Second array enumerates g(z*)
561550 Sorting order of extrema is arbitrary.
562551
563552 """
564553 ext = polyroot_vec (
565- c = dg_dz , a_min = jnp .array ([0.0 ]), a_max = jnp .diff (knots ), sentinel = sentinel
554+ c = g [..., :- 1 ] * jnp .arange (g .shape [- 1 ] - 1 , 0 , - 1 ),
555+ a_min = jnp .array ([0.0 ]),
556+ a_max = jnp .diff (knots ),
557+ sentinel = sentinel ,
566558 )
567- g_ext = flatten_mat (polyval_vec (x = ext , c = g [..., None , :]))
559+ g_ext = flatten_mat (cubic_val (x = ext , c = g [..., None , :]))
568560 # Transform out of local power basis expansion.
569561 ext = flatten_mat (ext + knots [:- 1 , None ])
570562 assert ext .shape == g_ext .shape
@@ -991,23 +983,6 @@ def round_up_rule(Y, NFP, axisymmetric=False):
991983 return num_z * NFP , num_z
992984
993985
994- def fieldline_quad_rule (Y ):
995- """Ensure field line quadrature has reasonable resolution.
996-
997- Parameters
998- ----------
999- Y : int
1000- Resolution of Chebyshev spectrum of angle over one field period.
1001-
1002- Returns
1003- -------
1004- Y : int
1005- Resolution for Gauss-Legendre quadrature over one field period.
1006-
1007- """
1008- return max (Y , 8 )
1009-
1010-
1011986def Y_B_rule (Y , NFP , spline = True ):
1012987 """Guess Y_B from resolution of Chebyshev spectrum of angle."""
1013988 return (2 * Y * int (np .sqrt (NFP ))) if spline else Y
@@ -1037,6 +1012,45 @@ def get_vander(grid, Y, Y_B, NFP):
10371012 return {"dct spline" : chebvander (x , Y_trunc - 1 )}
10381013
10391014
1015+ @jax .checkpoint
1016+ def _gather_reduce (y , cheb , x_idx ):
1017+ """Gather then reduce with checkpointing.
1018+
1019+ Checkpointing this makes it faster and reduces memory.
1020+ On many architectures, the cost of allocating contiguous blocks of memory
1021+ exceeds the cost of flops. By checkpointing, we avoid that in favor of
1022+ recomputing derivatives in the backward pass.
1023+
1024+ Lies:
1025+ https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#practical-notes
1026+ """
1027+ return idct_mmt (y , jnp .take_along_axis (cheb , x_idx [..., None ], axis = - 2 ))
1028+
1029+
1030+ @jax .checkpoint
1031+ def _loop (y , cheb , x_idx ):
1032+ """Memory efficient Clenshaw recursion.
1033+
1034+ Checkpointing on a CPU observed a minor reduction in memory usage while not
1035+ affecting speed. JAX/XLA has poor performance with iterative algorithms compared
1036+ to languages like Julia. On JAX version 0.7.2, this is slower than the product sum
1037+ reduction above. Without checkpointing either, this uses signficantly less memory
1038+ than above.
1039+ """
1040+
1041+ def body (i , val ):
1042+ c0 , c1 = val
1043+ return jnp .take_along_axis (cheb [- i ], x_idx , axis = - 1 ) - c1 , c0 + c1 * y2
1044+
1045+ num_coef = cheb .shape [- 1 ]
1046+ cheb = jnp .moveaxis (cheb , - 1 , 0 ) # to minimize cache misses
1047+ y2 = 2 * y
1048+ c0 = jnp .take_along_axis (cheb [- 2 ], x_idx , axis = - 1 )
1049+ c1 = jnp .take_along_axis (cheb [- 1 ], x_idx , axis = - 1 )
1050+ c0 , c1 = fori_loop (3 , num_coef + 1 , body , (c0 , c1 ))
1051+ return c0 + c1 * y
1052+
1053+
10401054class PiecewiseChebyshevSeries (IOAble ):
10411055 """Chebyshev series.
10421056
@@ -1167,8 +1181,10 @@ def eval1d(self, z, cheb=None, loop=False):
11671181 Shape (..., X, Y).
11681182 Chebyshev coefficients to use. If not given, uses ``self.cheb``.
11691183 loop : bool
1170- Whether to use Clenshaw recursion.
1171- This is slower on CPU, but it reduces memory of the Jacobian.
1184+ If ``True``, then uses Clenshaw recursion which is memory efficient.
1185+ If ``False``, then gathers a large block of memory and computes
1186+ a product sum reduction while checkpointing the derivative
1187+ to reduce memory consumption of the Jacobian.
11721188
11731189 Returns
11741190 -------
@@ -1178,26 +1194,8 @@ def eval1d(self, z, cheb=None, loop=False):
11781194 """
11791195 cheb = setdefault (cheb , self .cheb )
11801196 x_idx , y = self ._isomorphism_to_C2 (z )
1181-
11821197 y = bijection_to_disc (y , self .domain [0 ], self .domain [- 1 ])
1183-
1184- # Recall that the Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z])
1185- # are in cheb array whose shape is (..., num cheb series, spectral resolution).
1186-
1187- if not loop or self .Y < 3 :
1188- cheb = jnp .take_along_axis (cheb , x_idx [..., None ], axis = - 2 )
1189- return idct_mmt (y , cheb )
1190-
1191- def body (i , val ):
1192- c0 , c1 = val
1193- return jnp .take_along_axis (cheb [- i ], x_idx , axis = - 1 ) - c1 , c0 + c1 * y2
1194-
1195- cheb = jnp .moveaxis (cheb , - 1 , 0 ) # to leverage cache
1196- y2 = 2 * y
1197- c0 = jnp .take_along_axis (cheb [- 2 ], x_idx , axis = - 1 )
1198- c1 = jnp .take_along_axis (cheb [- 1 ], x_idx , axis = - 1 )
1199- c0 , c1 = fori_loop (3 , self .Y + 1 , body , (c0 , c1 ))
1200- return c0 + c1 * y
1198+ return (_loop if (loop and self .Y >= 3 ) else _gather_reduce )(y , cheb , x_idx )
12011199
12021200 def intersect2d (self , k = 0.0 , * , eps = _eps ):
12031201 """Coordinates yᵢ such that f(x, yᵢ) = k(x).
0 commit comments