Skip to content

QuasisymmetryBoozer objective throwing NonConcreteBooleanIndexError when used in optimization #625

@dpanici

Description

@dpanici

Adding the following code to the end of the QS optimization tutorial notebook and running it:

objective = ObjectiveFunction(
    (
        QuasisymmetryBoozer(helicity=(1, eq_init.NFP)),
    )
)

eq_qs_B, result_B = eq_init.optimize(
    objective=objective,
    constraints=constraints,
    optimizer=optimizer,
    ftol=1e-2,
    xtol=1e-6,
    gtol=1e-6,
    maxiter=50,
    options={
        "perturb_options": {"order": 2, "verbose": 0},
        "solve_options": {"ftol": 1e-2, "xtol": 1e-6, "gtol": 1e-6, "verbose": 0},
    },
    copy=True,
    verbose=3,
)

throws the following error:

---------------------------------------------------------------------------
NonConcreteBooleanIndexError              Traceback (most recent call last)
/tmp/ipykernel_5649/1841851383.py in <module>
      5 )
      6 
----> 7 eq_qs_B, result_B = eq_init.optimize(
      8     objective=objective,
      9     constraints=constraints,

~/DESC/desc/equilibrium/equilibrium.py in optimize(self, objective, constraints, optimizer, ftol, xtol, gtol, ctol, maxiter, x_scale, options, verbose, copy)
   1739             eq = self
   1740 
-> 1741         result = optimizer.optimize(
   1742             eq,
   1743             objective,

~/DESC/desc/optimize/optimizer.py in optimize(self, eq, objective, constraints, ftol, xtol, gtol, ctol, x_scale, verbose, maxiter, options)
    187                 mode = "lsq"
    188             try:
--> 189                 objective.compile(mode, verbose)
    190             except ValueError:
    191                 objective.build(eq, verbose=verbose)

~/DESC/desc/optimize/_constraint_wrappers.py in compile(self, mode, verbose)
    133 
    134         """
--> 135         self._objective.compile(mode, verbose)
    136 
    137     def project(self, x):

~/DESC/desc/optimize/_constraint_wrappers.py in compile(self, mode, verbose)
    609 
    610         """
--> 611         self._objective.compile(mode, verbose)
    612         self._constraint.compile(mode, verbose)
    613 

~/DESC/desc/objectives/objective_funs.py in compile(self, mode, verbose)
    622         if mode in ["lsq", "all"]:
    623             timer.start("Objective compilation time")
--> 624             _ = self.compute_scaled(x, self.constants).block_until_ready()
    625             timer.stop("Objective compilation time")
    626             if verbose > 1:

    [... skipping hidden 12 frame]

~/DESC/desc/objectives/objective_funs.py in compute_scaled(self, x, constants)
    339             constants = self.constants
    340         f = jnp.concatenate(
--> 341             [
    342                 obj.compute_scaled(
    343                     *self._kwargs_to_args(kwargs, obj.args), constants=const

~/DESC/desc/objectives/objective_funs.py in <listcomp>(.0)
    340         f = jnp.concatenate(
    341             [
--> 342                 obj.compute_scaled(
    343                     *self._kwargs_to_args(kwargs, obj.args), constants=const
    344                 )

    [... skipping hidden 12 frame]

~/DESC/desc/objectives/objective_funs.py in compute_scaled(self, *args, **kwargs)
    952     def compute_scaled(self, *args, **kwargs):
    953         """Compute and apply weighting and normalization."""
--> 954         f = self.compute(*args, **kwargs)
    955         return self._scale(f)
    956 

~/DESC/desc/objectives/_qs.py in compute(self, *args, **kwargs)
    198         )
    199         B_mn = constants["matrix"] @ data["|B|_mn"]
--> 200         return B_mn[constants["idx"]]
    201 
    202     @property

~/.local/lib/python3.8/site-packages/jax/_src/numpy/array_methods.py in op(self, *args)
    789 def _forward_operator_to_aval(name):
    790   def op(self, *args):
--> 791     return getattr(self.aval, f"_{name}")(self, *args)
    792   return op
    793 

~/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4140         return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
   4141 
-> 4142   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   4143   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4144                  unique_indices, mode, fill_value)

~/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _split_index_for_jit(idx, shape)
   4218   # Expand any (concrete) boolean indices. We can then use advanced integer
   4219   # indexing logic to handle them.
-> 4220   idx = _expand_bool_indices(idx, shape)
   4221 
   4222   leaves, treedef = tree_flatten(idx)

~/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _expand_bool_indices(idx, shape)
   4532       if not type(abstract_i) is ConcreteArray:
   4533         # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
-> 4534         raise errors.NonConcreteBooleanIndexError(abstract_i)
   4535       elif _ndim(i) == 0:
   4536         raise TypeError("JAX arrays do not support boolean scalar indices")

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[281])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions