Skip to content

Commit ecc14fa

Browse files
authored
Merge pull request #635 from PlasmaControl/pk/hotfix-jax-cond
Change parameters in cond for surface integrals for older jax versions (power9)
2 parents 94eef4b + c73fe0f commit ecc14fa

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

desc/compute/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -825,8 +825,9 @@ def surface_integrals_map(grid, surface_label="rho", expand_out=True):
825825
# previous paragraph.
826826
masks = cond(
827827
has_endpoint_dupe,
828-
lambda: put(masks, jnp.array([0, -1]), masks[0] | masks[-1]),
829-
lambda: masks,
828+
lambda _: put(masks, jnp.array([0, -1]), masks[0] | masks[-1]),
829+
lambda _: masks,
830+
operand=None,
830831
)
831832
spacing = jnp.prod(spacing, axis=1)
832833

0 commit comments

Comments
 (0)