Skip to content

Commit 887213f

Browse files
authored
Merge branch 'PlasmaControl:master' into rg/adjoint_ballooning
2 parents 9acd5fa + 61797b6 commit 887213f

File tree

117 files changed

+4280
-2501
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

117 files changed

+4280
-2501
lines changed

.test_durations

Lines changed: 443 additions & 418 deletions
Large diffs are not rendered by default.

desc/backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def fori_loop(lower, upper, body_fun, init_val):
201201
val = body_fun(i, val)
202202
return val
203203

204-
def cond(pred, true_fun, false_fun, operand):
204+
def cond(pred, true_fun, false_fun, *operand):
205205
"""Conditionally apply true_fun or false_fun.
206206
207207
This version is for the numpy backend, for jax backend see jax.lax.cond
@@ -227,9 +227,9 @@ def cond(pred, true_fun, false_fun, operand):
227227
228228
"""
229229
if pred:
230-
return true_fun(operand)
230+
return true_fun(*operand)
231231
else:
232-
return false_fun(operand)
232+
return false_fun(*operand)
233233

234234
def switch(index, branches, operand):
235235
"""Apply exactly one of branches given by index.

desc/compute/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
Parameters
66
----------
77
params : dict of ndarray
8-
Parameters from the equilibrium, such as R_lmn, Z_lmn, i_l, p_l, etc
8+
Parameters from the equilibrium, such as R_lmn, Z_lmn, i_l, p_l, etc.
99
transforms : dict of Transform
10-
Transforms for R, Z, lambda, etc
10+
Transforms for R, Z, lambda, etc.
1111
profiles : dict of Profile
12-
Profile objects for pressure, iota, current, etc
12+
Profile objects for pressure, iota, current, etc.
1313
data : dict of ndarray
1414
Data computed so far, generally output from other compute functions
1515
kwargs : dict
@@ -59,8 +59,8 @@
5959
# import the compute module.
6060
def _build_data_index():
6161

62-
for p in data_index.keys():
63-
for key in data_index[p].keys():
62+
for p in data_index:
63+
for key in data_index[p]:
6464
full = {
6565
"data": get_data_deps(key, p, has_axis=False),
6666
"transforms": get_derivs(key, p, has_axis=False),

desc/compute/_basis_vectors.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
transforms={},
2727
profiles=[],
2828
coordinates="rtz",
29-
data=["B"],
29+
data=["B", "|B|"],
3030
)
3131
def _b(params, transforms, profiles, data, **kwargs):
32-
data["b"] = (data["B"].T / jnp.linalg.norm(data["B"], axis=-1)).T
32+
data["b"] = (data["B"].T / data["|B|"]).T
3333
return data
3434

3535

@@ -47,6 +47,8 @@ def _b(params, transforms, profiles, data, **kwargs):
4747
data=["e_theta/sqrt(g)", "e_zeta"],
4848
)
4949
def _e_sup_rho(params, transforms, profiles, data, **kwargs):
50+
# At the magnetic axis, this function returns the multivalued map whose
51+
# image is the set { 𝐞^ρ | ρ=0 }.
5052
data["e^rho"] = cross(data["e_theta/sqrt(g)"], data["e_zeta"])
5153
return data
5254

@@ -196,8 +198,14 @@ def _e_sup_theta(params, transforms, profiles, data, **kwargs):
196198
profiles=[],
197199
coordinates="rtz",
198200
data=["e_rho", "e_zeta"],
201+
parameterization=[
202+
"desc.equilibrium.equilibrium.Equilibrium",
203+
"desc.geometry.core.Surface",
204+
],
199205
)
200206
def _e_sup_theta_times_sqrt_g(params, transforms, profiles, data, **kwargs):
207+
# At the magnetic axis, this function returns the multivalued map whose
208+
# image is the set { 𝐞^θ √g | ρ=0 }.
201209
data["e^theta*sqrt(g)"] = cross(data["e_zeta"], data["e_rho"])
202210
return data
203211

@@ -299,6 +307,8 @@ def _e_sup_theta_z(params, transforms, profiles, data, **kwargs):
299307
data=["e_rho", "e_theta/sqrt(g)"],
300308
)
301309
def _e_sup_zeta(params, transforms, profiles, data, **kwargs):
310+
# At the magnetic axis, this function returns the multivalued map whose
311+
# image is the set { 𝐞^ζ | ρ=0 }.
302312
data["e^zeta"] = cross(data["e_rho"], data["e_theta/sqrt(g)"])
303313
return data
304314

@@ -453,6 +463,8 @@ def _e_sub_phi(params, transforms, profiles, data, **kwargs):
453463
data=["R", "R_r", "Z_r", "omega_r"],
454464
)
455465
def _e_sub_rho(params, transforms, profiles, data, **kwargs):
466+
# At the magnetic axis, this function returns the multivalued map whose
467+
# image is the set { 𝐞ᵨ | ρ=0 }.
456468
data["e_rho"] = jnp.array([data["R_r"], data["R"] * data["omega_r"], data["Z_r"]]).T
457469
return data
458470

@@ -1386,6 +1398,8 @@ def _e_sub_theta(params, transforms, profiles, data, **kwargs):
13861398
axis_limit_data=["e_theta_r", "sqrt(g)_r"],
13871399
)
13881400
def _e_sub_theta_over_sqrt_g(params, transforms, profiles, data, **kwargs):
1401+
# At the magnetic axis, this function returns the multivalued map whose
1402+
# image is the set { 𝐞_θ / √g | ρ=0 }.
13891403
data["e_theta/sqrt(g)"] = transforms["grid"].replace_at_axis(
13901404
(data["e_theta"].T / data["sqrt(g)"]).T,
13911405
lambda: (data["e_theta_r"].T / data["sqrt(g)_r"]).T,
@@ -1426,6 +1440,8 @@ def _e_sub_theta_pest(params, transforms, profiles, data, **kwargs):
14261440
data=["R", "R_r", "R_rt", "R_t", "Z_rt", "omega_r", "omega_rt", "omega_t"],
14271441
)
14281442
def _e_sub_theta_r(params, transforms, profiles, data, **kwargs):
1443+
# At the magnetic axis, this function returns the multivalued map whose
1444+
# image is the set { ∂ᵨ 𝐞_θ | ρ=0 }
14291445
data["e_theta_r"] = jnp.array(
14301446
[
14311447
-data["R"] * data["omega_t"] * data["omega_r"] + data["R_rt"],
@@ -3428,16 +3444,22 @@ def _gradpsi(params, transforms, profiles, data, **kwargs):
34283444
profiles=[],
34293445
coordinates="rtz",
34303446
data=["e_theta", "e_zeta", "|e_theta x e_zeta|"],
3447+
axis_limit_data=["e_theta_r", "|e_theta x e_zeta|_r"],
34313448
parameterization=[
34323449
"desc.equilibrium.equilibrium.Equilibrium",
34333450
"desc.geometry.core.Surface",
34343451
],
34353452
)
34363453
def _n_rho(params, transforms, profiles, data, **kwargs):
3437-
# equal to e^rho / |e^rho| but works correctly for surfaces as well that don't have
3438-
# contravariant basis defined
3439-
data["n_rho"] = (
3440-
cross(data["e_theta"], data["e_zeta"]) / data["|e_theta x e_zeta|"][:, None]
3454+
# Equal to 𝐞^ρ / ‖𝐞^ρ‖ but works correctly for surfaces as well that don't
3455+
# have contravariant basis defined.
3456+
data["n_rho"] = transforms["grid"].replace_at_axis(
3457+
(cross(data["e_theta"], data["e_zeta"]).T / data["|e_theta x e_zeta|"]).T,
3458+
# At the magnetic axis, this function returns the multivalued map whose
3459+
# image is the set { 𝐞^ρ / ‖𝐞^ρ‖ | ρ=0 }.
3460+
lambda: (
3461+
cross(data["e_theta_r"], data["e_zeta"]).T / data["|e_theta x e_zeta|_r"]
3462+
).T,
34413463
)
34423464
return data
34433465

@@ -3460,9 +3482,11 @@ def _n_rho(params, transforms, profiles, data, **kwargs):
34603482
],
34613483
)
34623484
def _n_theta(params, transforms, profiles, data, **kwargs):
3485+
# Equal to 𝐞^θ / ‖𝐞^θ‖ but works correctly for surfaces as well that don't
3486+
# have contravariant basis defined.
34633487
data["n_theta"] = (
3464-
cross(data["e_zeta"], data["e_rho"]) / data["|e_zeta x e_rho|"][:, None]
3465-
)
3488+
cross(data["e_zeta"], data["e_rho"]).T / data["|e_zeta x e_rho|"]
3489+
).T
34663490
return data
34673491

34683492

@@ -3478,13 +3502,21 @@ def _n_theta(params, transforms, profiles, data, **kwargs):
34783502
profiles=[],
34793503
coordinates="rtz",
34803504
data=["e_rho", "e_theta", "|e_rho x e_theta|"],
3505+
axis_limit_data=["e_theta_r", "|e_rho x e_theta|_r"],
34813506
parameterization=[
34823507
"desc.equilibrium.equilibrium.Equilibrium",
34833508
"desc.geometry.core.Surface",
34843509
],
34853510
)
34863511
def _n_zeta(params, transforms, profiles, data, **kwargs):
3487-
data["n_zeta"] = (
3488-
cross(data["e_rho"], data["e_theta"]) / data["|e_rho x e_theta|"][:, None]
3512+
# Equal to 𝐞^ζ / ‖𝐞^ζ‖ but works correctly for surfaces as well that don't
3513+
# have contravariant basis defined.
3514+
data["n_zeta"] = transforms["grid"].replace_at_axis(
3515+
(cross(data["e_rho"], data["e_theta"]).T / data["|e_rho x e_theta|"]).T,
3516+
# At the magnetic axis, this function returns the multivalued map whose
3517+
# image is the set { 𝐞^ζ / ‖𝐞^ζ‖ | ρ=0 }.
3518+
lambda: (
3519+
cross(data["e_rho"], data["e_theta_r"]).T / data["|e_rho x e_theta|_r"]
3520+
).T,
34893521
)
34903522
return data

0 commit comments

Comments
 (0)