Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.constants import mu_0
from termcolor import colored

from desc.backend import jnp
from desc.basis import FourierZernikeBasis, fourier, zernike_radial
from desc.compute import compute as compute_fun
from desc.compute import data_index
Expand Down Expand Up @@ -1117,7 +1118,14 @@ def R_lmn(self):

@R_lmn.setter
def R_lmn(self, R_lmn):
self._R_lmn[:] = R_lmn
R_lmn = jnp.atleast_1d(R_lmn)
errorif(
R_lmn.size != self._R_lmn.size,
ValueError,
"R_lmn should have the same size as R_basis, "
+ f"got {len(R_lmn)} for basis with {self.R_basis.num_modes} modes",
)
self._R_lmn = R_lmn

@property
def Z_lmn(self):
Expand All @@ -1126,7 +1134,14 @@ def Z_lmn(self):

@Z_lmn.setter
def Z_lmn(self, Z_lmn):
self._Z_lmn[:] = Z_lmn
Z_lmn = jnp.atleast_1d(Z_lmn)
errorif(
Z_lmn.size != self._Z_lmn.size,
ValueError,
"Z_lmn should have the same size as Z_basis, "
+ f"got {len(Z_lmn)} for basis with {self.Z_basis.num_modes} modes",
)
self._Z_lmn = Z_lmn

@property
def L_lmn(self):
Expand All @@ -1135,7 +1150,14 @@ def L_lmn(self):

@L_lmn.setter
def L_lmn(self, L_lmn):
self._L_lmn[:] = L_lmn
L_lmn = jnp.atleast_1d(L_lmn)
errorif(
L_lmn.size != self._L_lmn.size,
ValueError,
"L_lmn should have the same size as L_basis, "
+ f"got {len(L_lmn)} for basis with {self.L_basis.num_modes} modes",
)
self._L_lmn = L_lmn

@property
def Rb_lmn(self):
Expand Down
61 changes: 41 additions & 20 deletions desc/equilibrium/initial_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from desc.backend import jnp
from desc.backend import fori_loop, jit, jnp, put
from desc.basis import zernike_radial
from desc.geometry import FourierRZCurve, Surface
from desc.grid import Grid
Expand Down Expand Up @@ -220,7 +220,7 @@ def _initial_guess_surface(x_basis, b_lmn, b_basis, axis=None, mode=None, coord=

Parameters
----------
x_basis : FourierZernikeBais
x_basis : FourierZernikeBasis
basis of the flux surfaces (for R, Z, or Lambda).
b_lmn : ndarray, shape(b_basis.num_modes,)
vector of boundary coefficients associated with b_basis.
Expand All @@ -242,7 +242,11 @@ def _initial_guess_surface(x_basis, b_lmn, b_basis, axis=None, mode=None, coord=
vector of flux surface coefficients associated with x_basis.

"""
x_lmn = np.zeros((x_basis.num_modes,))
b_modes = jnp.asarray(b_basis.modes)
x_modes = jnp.asarray(x_basis.modes)
b_lmn = jnp.asarray(b_lmn)
x_lmn = jnp.zeros((x_basis.num_modes,))

if mode is None:
# auto-detect based on mode numbers
if np.all(b_basis.modes[:, 0] == 0):
Expand All @@ -256,28 +260,45 @@ def _initial_guess_surface(x_basis, b_lmn, b_basis, axis=None, mode=None, coord=
coord = 1.0
if axis is None:
axidx = np.where(b_basis.modes[:, 1] == 0)[0]
axis = np.array([b_basis.modes[axidx, 2], b_lmn[axidx]]).T
for k, (l, m, n) in enumerate(b_basis.modes):
axis = jnp.array([b_basis.modes[axidx, 2], b_lmn[axidx]]).T

# first do all the m != 0 modes, easiest since no special logic needed
def body(k, x_lmn):
l, m, n = b_modes[k]
scale = zernike_radial(coord, abs(m), m)
# index of basis mode with lowest radial power (l = |m|)
idx0 = np.where((x_basis.modes == [np.abs(m), m, n]).all(axis=1))[0]
if m == 0: # magnetic axis only affects m=0 modes
# index of basis mode with second lowest radial power (l = |m| + 2)
idx2 = np.where((x_basis.modes == [np.abs(m) + 2, m, n]).all(axis=1))[0]
ax = np.where(axis[:, 0] == n)[0]
if ax.size:
a_n = axis[ax[0], 1] # use provided axis guess
else:
a_n = b_lmn[k] # use boundary centroid as axis
x_lmn[idx0] = (b_lmn[k] + a_n) / 2 / scale
x_lmn[idx2] = (b_lmn[k] - a_n) / 2 / scale
mask0 = (x_modes == jnp.array([abs(m), m, n])).all(axis=1)
x_lmn = jnp.where(mask0, b_lmn[k] / scale, x_lmn)
return x_lmn

x_lmn = fori_loop(0, b_basis.num_modes, body, x_lmn)

# now overwrite stuff to deal with the axis
scale = zernike_radial(coord, 0, 0)
for k, (l, m, n) in enumerate(b_basis.modes):
if m != 0:
continue
# index of basis mode with lowest radial power (l = |m|)
idx0 = np.where((x_basis.modes == [abs(m), m, n]).all(axis=1))[0]
# index of basis mode with second lowest radial power (l = |m| + 2)
idx2 = np.where((x_basis.modes == [abs(m) + 2, m, n]).all(axis=1))[0]
ax = np.where(axis[:, 0] == n)[0]
if ax.size:
a_n = axis[ax[0], 1] # use provided axis guess
else:
x_lmn[idx0] = b_lmn[k] / scale
a_n = b_lmn[k] # use boundary centroid as axis
x_lmn = jit(put)(x_lmn, idx0, (b_lmn[k] + a_n) / 2 / scale)
x_lmn = jit(put)(x_lmn, idx2, (b_lmn[k] - a_n) / 2 / scale)

elif mode == "poincare":
for k, (l, m, n) in enumerate(b_basis.modes):
idx = np.where((x_basis.modes == [l, m, n]).all(axis=1))[0]
x_lmn[idx] = b_lmn[k]

def body(k, x_lmn):
l, m, n = b_modes[k]
mask0 = (x_modes == jnp.array([l, m, n])).all(axis=1)
x_lmn = jnp.where(x_lmn, mask0, b_lmn[k], x_lmn)
return x_lmn

x_lmn = fori_loop(0, b_basis.num_modes, body, x_lmn)

else:
raise ValueError("Boundary mode should be either 'lcfs' or 'poincare'.")
Expand Down
25 changes: 16 additions & 9 deletions desc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from scipy.special import factorial
from termcolor import colored

from desc.backend import fori_loop, jit, jnp


class Timer:
"""Simple object for organizing timing info.
Expand Down Expand Up @@ -361,22 +363,27 @@ def islinspaced(x, axis=-1, rtol=1e-6, atol=1e-12):
return isalmostequal(np.diff(x, axis=axis), rtol=rtol, atol=atol, axis=axis)


@jit
def copy_coeffs(c_old, modes_old, modes_new, c_new=None):
"""Copy coefficients from one resolution to another."""
modes_old, modes_new = np.atleast_1d(modes_old), np.atleast_1d(modes_new)
modes_old, modes_new = jnp.atleast_1d(modes_old), jnp.atleast_1d(modes_new)

if modes_old.ndim == 1:
modes_old = modes_old.reshape((-1, 1))
if modes_new.ndim == 1:
modes_new = modes_new.reshape((-1, 1))

num_modes = modes_new.shape[0]
if c_new is None:
c_new = np.zeros((num_modes,))
c_new = jnp.zeros((modes_new.shape[0],))
c_old, c_new = jnp.asarray(c_old), jnp.asarray(c_new)

def body(i, c_new):
mask = (modes_old[i, :] == modes_new).all(axis=1)
c_new = jnp.where(mask, c_old[i], c_new)
return c_new

for i in range(num_modes):
idx = np.where((modes_old == modes_new[i, :]).all(axis=1))[0]
if len(idx):
c_new[i] = c_old[idx]
if c_old.size:
c_new = fori_loop(0, modes_old.shape[0], body, c_new)
return c_new


Expand Down Expand Up @@ -422,7 +429,7 @@ def combination_permutation(m, n, equals=True):
n : int
Maximum sum
equals : bool
If True, return only where sum == n, else retun where sum <= n
If True, return only where sum == n, else return where sum <= n

Returns
-------
Expand Down Expand Up @@ -478,7 +485,7 @@ def get_instance(things, cls):


def parse_argname_change(arg, kwargs, oldname, newname):
"""Warn and parse arguemnts whose names have changed."""
"""Warn and parse arguments whose names have changed."""
if oldname in kwargs:
warnings.warn(
FutureWarning(
Expand Down
5 changes: 3 additions & 2 deletions tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pytest

from desc.backend import put
from desc.equilibrium import EquilibriaFamily, Equilibrium
from desc.equilibrium.initial_guess import _initial_guess_surface
from desc.geometry import (
Expand Down Expand Up @@ -470,9 +471,9 @@ def test_is_nested():
assert eq.is_nested(grid=grid)

eq.change_resolution(L=2, M=2)
eq.R_lmn[eq.R_basis.get_idx(L=1, M=1, N=0)] = 1
eq.R_lmn = put(eq.R_lmn, eq.R_basis.get_idx(L=1, M=1, N=0), 1)
# make unnested by setting higher order mode to same amplitude as lower order mode
eq.R_lmn[eq.R_basis.get_idx(L=2, M=2, N=0)] = 1
eq.R_lmn = put(eq.R_lmn, eq.R_basis.get_idx(L=2, M=2, N=0), 1)

assert not eq.is_nested(grid=grid)
with pytest.warns(Warning) as record:
Expand Down