Skip to content

Commit 5c993b9

Browse files
committed
Add checkpointing
1 parent afa4b5e commit 5c993b9

File tree

10 files changed

+127
-222
lines changed

10 files changed

+127
-222
lines changed

desc/compute/_fast_ion.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def _Gamma_c(params, transforms, profiles, data, **kwargs):
125125
nufft_eps,
126126
spline,
127127
vander,
128-
low_ram,
129128
) = Bounce2D._default_kwargs("weak", grid.NFP, **kwargs)
130129

131130
def Gamma_c(data):
@@ -152,7 +151,6 @@ def fun(pitch_inv):
152151
["|grad(psi)|*kappa_g", "|B|_r|v,p", "K"],
153152
points,
154153
nufft_eps=nufft_eps,
155-
low_ram=low_ram,
156154
is_fourier=True,
157155
)
158156
# This is γ_c π/2.
@@ -278,7 +276,6 @@ def _little_gamma_c_Nemov(params, transforms, profiles, data, **kwargs):
278276
nufft_eps,
279277
spline,
280278
vander,
281-
low_ram,
282279
) = Bounce2D._default_kwargs("weak", grid.NFP, **kwargs)
283280

284281
def gamma_c0(data):
@@ -305,7 +302,6 @@ def fun(pitch_inv):
305302
["|grad(psi)|*kappa_g", "|B|_r|v,p", "K"],
306303
points,
307304
nufft_eps=nufft_eps,
308-
low_ram=low_ram,
309305
is_fourier=True,
310306
)
311307
return (2 / jnp.pi) * jnp.arctan(
@@ -402,7 +398,6 @@ def _Gamma_c_Velasco(params, transforms, profiles, data, **kwargs):
402398
nufft_eps,
403399
spline,
404400
vander,
405-
low_ram,
406401
) = Bounce2D._default_kwargs("weak", grid.NFP, **kwargs)
407402

408403
def Gamma_c(data):
@@ -428,7 +423,6 @@ def fun(pitch_inv):
428423
["cvdrift0", "gbdrift (periodic)", "gbdrift (secular)/phi"],
429424
num_well=num_well,
430425
nufft_eps=nufft_eps,
431-
low_ram=low_ram,
432426
is_fourier=True,
433427
)
434428
# This is γ_c π/2.

desc/compute/_neoclassical.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,6 @@
8080
This private parameter is intended to be used only by
8181
developers for objectives.
8282
""",
83-
"low_ram": """bool :
84-
If true, then will switch to a slower algorithm whose differentiation
85-
consumes less memory. Default is false.
86-
""",
8783
"theta": "",
8884
}
8985

@@ -97,7 +93,6 @@
9793
"surf_batch_size",
9894
"nufft_eps",
9995
"spline",
100-
"low_ram",
10196
)
10297

10398

@@ -201,7 +196,6 @@ def _epsilon_32(params, transforms, profiles, data, **kwargs):
201196
nufft_eps,
202197
spline,
203198
vander,
204-
low_ram,
205199
) = Bounce2D._default_kwargs("deriv", grid.NFP, **kwargs)
206200

207201
def eps_32(data):
@@ -231,7 +225,6 @@ def fun(pitch_inv):
231225
"|grad(rho)|*kappa_g",
232226
num_well=num_well,
233227
nufft_eps=nufft_eps,
234-
low_ram=low_ram,
235228
is_fourier=True,
236229
)
237230
return safediv(I_1**2, I_2).sum(-1).mean(-2)

desc/integrals/_bounce_utils.py

Lines changed: 59 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
from matplotlib import pyplot as plt
2222
from orthax.chebyshev import chebroots, chebvander
2323

24-
from desc.backend import dct, flatnonzero, fori_loop, idct, ifft, jnp
24+
from desc.backend import dct, flatnonzero, fori_loop, idct, ifft, jax, jnp
2525
from desc.integrals._interp_utils import (
2626
_eps,
2727
_filter_distinct,
2828
_subtract_first,
29+
cubic_val,
2930
nufft1d2r,
3031
polyroot_vec,
31-
polyval_vec,
3232
)
3333
from desc.integrals.quad_utils import bijection_from_disc, bijection_to_disc
3434
from desc.io import IOAble
@@ -97,7 +97,7 @@ def _in_epigraph_and(is_intersect, df_dy, /):
9797
return is_intersect.at[idx[0]].set(edge_case)
9898

9999

100-
def bounce_points(pitch_inv, knots, B, dB_dz, num_well=None):
100+
def bounce_points(pitch_inv, knots, B, num_well=None):
101101
"""Compute the bounce points given 1D spline of B and pitch λ.
102102
103103
Parameters
@@ -109,15 +109,10 @@ def bounce_points(pitch_inv, knots, B, dB_dz, num_well=None):
109109
Shape (N, ).
110110
ζ coordinates of spline knots. Must be strictly increasing.
111111
B : jnp.ndarray
112-
Shape (..., N - 1, B.shape[-1]).
112+
Shape (..., N - 1, 4).
113113
Polynomial coefficients of the spline of B in local power basis.
114114
Last axis enumerates the coefficients of power series. Second to
115115
last axis enumerates the polynomials that compose a particular spline.
116-
dB_dz : jnp.ndarray
117-
Shape (..., N - 1, B.shape[-1] - 1).
118-
Polynomial coefficients of the spline of (∂B/∂ζ)|(ρ,α) in local power basis.
119-
Last axis enumerates the coefficients of power series. Second to
120-
last axis enumerates the polynomials that compose a particular spline.
121116
num_well : int or None
122117
Specify to return the first ``num_well`` pairs of bounce points for each
123118
pitch and field line. Choosing ``-1`` will detect all wells, but due
@@ -145,8 +140,9 @@ def bounce_points(pitch_inv, knots, B, dB_dz, num_well=None):
145140
line and pitch, is padded with zero.
146141
147142
"""
143+
B = B[..., None, :, :]
148144
intersect = polyroot_vec(
149-
c=B[..., None, :, :],
145+
c=B,
150146
k=jnp.atleast_1d(pitch_inv)[..., None],
151147
a_min=jnp.array([0.0]),
152148
a_max=jnp.diff(knots),
@@ -156,9 +152,7 @@ def bounce_points(pitch_inv, knots, B, dB_dz, num_well=None):
156152
)
157153
assert intersect.shape[-2:] == (knots.size - 1, B.shape[-1] - 1)
158154

159-
dB_dz = flatten_mat(
160-
jnp.sign(polyval_vec(x=intersect, c=dB_dz[..., None, :, None, :]))
161-
)
155+
dB_dz = flatten_mat(jnp.sign(cubic_val(x=intersect, c=B[..., None, :], der=True)))
162156
# Only consider intersect if it is within knots that bound that polynomial.
163157
mask = flatten_mat(intersect >= 0)
164158
z1 = (dB_dz <= 0) & mask
@@ -532,7 +526,7 @@ def _plot_intersect(
532526
)
533527

534528

535-
def get_extrema(knots, g, dg_dz, sentinel=jnp.nan):
529+
def get_extrema(knots, g, sentinel=jnp.nan):
536530
"""Return extrema (z*, g(z*)).
537531
538532
Parameters
@@ -541,30 +535,28 @@ def get_extrema(knots, g, dg_dz, sentinel=jnp.nan):
541535
Shape (N, ).
542536
ζ coordinates of spline knots. Must be strictly increasing.
543537
g : jnp.ndarray
544-
Shape (..., N - 1, g.shape[-1]).
538+
Shape (..., N - 1, 4).
545539
Polynomial coefficients of the spline of g in local power basis.
546540
Last axis enumerates the coefficients of power series. Second to
547541
last axis enumerates the polynomials that compose a particular spline.
548-
dg_dz : jnp.ndarray
549-
Shape (..., N - 1, g.shape[-1] - 1).
550-
Polynomial coefficients of the spline of ∂g/∂z in local power basis.
551-
Last axis enumerates the coefficients of power series. Second to
552-
last axis enumerates the polynomials that compose a particular spline.
553542
sentinel : float
554543
Value with which to pad array to return fixed shape.
555544
556545
Returns
557546
-------
558547
ext, g_ext : jnp.ndarray
559-
Shape (..., (N - 1) * (g.shape[-1] - 2)).
548+
Shape (..., (N - 1) * 2).
560549
First array enumerates z*. Second array enumerates g(z*)
561550
Sorting order of extrema is arbitrary.
562551
563552
"""
564553
ext = polyroot_vec(
565-
c=dg_dz, a_min=jnp.array([0.0]), a_max=jnp.diff(knots), sentinel=sentinel
554+
c=g[..., :-1] * jnp.arange(g.shape[-1] - 1, 0, -1),
555+
a_min=jnp.array([0.0]),
556+
a_max=jnp.diff(knots),
557+
sentinel=sentinel,
566558
)
567-
g_ext = flatten_mat(polyval_vec(x=ext, c=g[..., None, :]))
559+
g_ext = flatten_mat(cubic_val(x=ext, c=g[..., None, :]))
568560
# Transform out of local power basis expansion.
569561
ext = flatten_mat(ext + knots[:-1, None])
570562
assert ext.shape == g_ext.shape
@@ -991,23 +983,6 @@ def round_up_rule(Y, NFP, axisymmetric=False):
991983
return num_z * NFP, num_z
992984

993985

994-
def fieldline_quad_rule(Y):
995-
"""Ensure field line quadrature has reasonable resolution.
996-
997-
Parameters
998-
----------
999-
Y : int
1000-
Resolution of Chebyshev spectrum of angle over one field period.
1001-
1002-
Returns
1003-
-------
1004-
Y : int
1005-
Resolution for Gauss-Legendre quadrature over one field period.
1006-
1007-
"""
1008-
return max(Y, 8)
1009-
1010-
1011986
def Y_B_rule(Y, NFP, spline=True):
1012987
"""Guess Y_B from resolution of Chebyshev spectrum of angle."""
1013988
return (2 * Y * int(np.sqrt(NFP))) if spline else Y
@@ -1037,6 +1012,45 @@ def get_vander(grid, Y, Y_B, NFP):
10371012
return {"dct spline": chebvander(x, Y_trunc - 1)}
10381013

10391014

1015+
@jax.checkpoint
1016+
def _gather_reduce(y, cheb, x_idx):
1017+
"""Gather then reduce with checkpointing.
1018+
1019+
Checkpointing this makes it faster and reduces memory.
1020+
On many architectures, the cost of allocating contiguous blocks of memory
1021+
exceeds the cost of flops. By checkpointing, we avoid that in favor of
1022+
recomputing derivatives in the backward pass.
1023+
1024+
Lies:
1025+
https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#practical-notes
1026+
"""
1027+
return idct_mmt(y, jnp.take_along_axis(cheb, x_idx[..., None], axis=-2))
1028+
1029+
1030+
@jax.checkpoint
1031+
def _loop(y, cheb, x_idx):
1032+
"""Memory efficient Clenshaw recursion.
1033+
1034+
Checkpointing on a CPU observed a minor reduction in memory usage while not
1035+
affecting speed. JAX/XLA has poor performance with iterative algorithms compared
1036+
to languages like Julia. On JAX version 0.7.2, this is slower than the product sum
1037+
reduction above. Without checkpointing either, this uses signficantly less memory
1038+
than above.
1039+
"""
1040+
1041+
def body(i, val):
1042+
c0, c1 = val
1043+
return jnp.take_along_axis(cheb[-i], x_idx, axis=-1) - c1, c0 + c1 * y2
1044+
1045+
num_coef = cheb.shape[-1]
1046+
cheb = jnp.moveaxis(cheb, -1, 0) # to minimize cache misses
1047+
y2 = 2 * y
1048+
c0 = jnp.take_along_axis(cheb[-2], x_idx, axis=-1)
1049+
c1 = jnp.take_along_axis(cheb[-1], x_idx, axis=-1)
1050+
c0, c1 = fori_loop(3, num_coef + 1, body, (c0, c1))
1051+
return c0 + c1 * y
1052+
1053+
10401054
class PiecewiseChebyshevSeries(IOAble):
10411055
"""Chebyshev series.
10421056
@@ -1167,8 +1181,10 @@ def eval1d(self, z, cheb=None, loop=False):
11671181
Shape (..., X, Y).
11681182
Chebyshev coefficients to use. If not given, uses ``self.cheb``.
11691183
loop : bool
1170-
Whether to use Clenshaw recursion.
1171-
This is slower on CPU, but it reduces memory of the Jacobian.
1184+
If ``True``, then uses Clenshaw recursion which is memory efficient.
1185+
If ``False``, then gathers a large block of memory and computes
1186+
a product sum reduction while checkpointing the derivative
1187+
to reduce memory consumption of the Jacobian.
11721188
11731189
Returns
11741190
-------
@@ -1178,26 +1194,8 @@ def eval1d(self, z, cheb=None, loop=False):
11781194
"""
11791195
cheb = setdefault(cheb, self.cheb)
11801196
x_idx, y = self._isomorphism_to_C2(z)
1181-
11821197
y = bijection_to_disc(y, self.domain[0], self.domain[-1])
1183-
1184-
# Recall that the Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z])
1185-
# are in cheb array whose shape is (..., num cheb series, spectral resolution).
1186-
1187-
if not loop or self.Y < 3:
1188-
cheb = jnp.take_along_axis(cheb, x_idx[..., None], axis=-2)
1189-
return idct_mmt(y, cheb)
1190-
1191-
def body(i, val):
1192-
c0, c1 = val
1193-
return jnp.take_along_axis(cheb[-i], x_idx, axis=-1) - c1, c0 + c1 * y2
1194-
1195-
cheb = jnp.moveaxis(cheb, -1, 0) # to leverage cache
1196-
y2 = 2 * y
1197-
c0 = jnp.take_along_axis(cheb[-2], x_idx, axis=-1)
1198-
c1 = jnp.take_along_axis(cheb[-1], x_idx, axis=-1)
1199-
c0, c1 = fori_loop(3, self.Y + 1, body, (c0, c1))
1200-
return c0 + c1 * y
1198+
return (_loop if (loop and self.Y >= 3) else _gather_reduce)(y, cheb, x_idx)
12011199

12021200
def intersect2d(self, k=0.0, *, eps=_eps):
12031201
"""Coordinates yᵢ such that f(x, yᵢ) = k(x).

desc/integrals/_interp_utils.py

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -161,32 +161,8 @@ def interp1d_Hermite_vec(xq, x, f, fx, /):
161161
return interp1d(xq, x, f, method="cubic", fx=fx)
162162

163163

164-
# TODO (#1388): Move to interpax.
165-
166-
167-
def polyder_vec(c):
168-
"""Coefficients for the derivatives of the given set of polynomials.
169-
170-
Parameters
171-
----------
172-
c : jnp.ndarray
173-
Last axis should store coefficients of a polynomial. For a polynomial given by
174-
∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at
175-
``c[...,n-i]``.
176-
177-
Returns
178-
-------
179-
poly : jnp.ndarray
180-
Coefficients of polynomial derivative, ignoring the arbitrary constant. That is,
181-
``poly[...,i]`` stores the coefficient of the monomial xⁿ⁻ⁱ⁻¹, where n is
182-
``c.shape[-1]-1``.
183-
184-
"""
185-
return c[..., :-1] * jnp.arange(c.shape[-1] - 1, 0, -1)
186-
187-
188-
def polyval_vec(*, x, c):
189-
"""Evaluate the set of polynomials ``c`` at the points ``x``.
164+
def cubic_val(*, x, c, der=False):
165+
"""Evaluate the derivative of cubic polynomial ``c`` at the points ``x``.
190166
191167
Parameters
192168
----------
@@ -196,6 +172,8 @@ def polyval_vec(*, x, c):
196172
Last axis should store coefficients of a polynomial. For a polynomial given by
197173
∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at
198174
``c[...,n-i]``.
175+
der : bool
176+
Whether to evaluate the derivative instead.
199177
200178
Returns
201179
-------
@@ -207,17 +185,16 @@ def polyval_vec(*, x, c):
207185
.. code-block:: python
208186
209187
np.testing.assert_allclose(
210-
polyval_vec(x=x, c=c),
188+
cubic_val(x=x, c=c),
211189
np.sum(polyvander(x, c.shape[-1] - 1) * c[..., ::-1], axis=-1),
212190
)
213191
214192
"""
215-
# Better than Horner's method as we expect to evaluate low order polynomials.
216-
# No need to use fast multipoint evaluation techniques for the same reason.
217-
return jnp.sum(
218-
c * x[..., jnp.newaxis] ** jnp.arange(c.shape[-1] - 1, -1, -1),
219-
axis=-1,
220-
)
193+
assert c.shape[-1] == 4
194+
if der:
195+
return (3 * c[..., 0] * x + 2 * c[..., 1]) * x + c[..., 2]
196+
else:
197+
return ((c[..., 0] * x + c[..., 1]) * x + c[..., 2]) * x + c[..., 3]
221198

222199

223200
def _subtract_first(c, k):
@@ -265,6 +242,8 @@ def _filter_distinct(r, sentinel, eps):
265242
)
266243
_eps = max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12)
267244

245+
# TODO (#1388): Move to interpax.
246+
268247

269248
def polyroot_vec(
270249
c,

0 commit comments

Comments
 (0)