Skip to content

Conversation

@pkim1818
Copy link
Contributor

The surface integrals routine have use the cond function
masks = cond( has_endpoint_dupe, lambda: put(masks, jnp.array([0, -1]), masks[0] | masks[-1]), lambda: masks, )
But in older versions of jax (like the ones used for the traverse builds), cond requires the parameter "operand" which are arguments to the functions. The resulting fix is

masks = cond( has_endpoint_dupe, lambda _: put(masks, jnp.array([0, -1]), masks[0] | masks[-1]), lambda _: masks, operand=None )
Which works since both lambda functions don't take in any arguments. This also works for the newer versions of jax.

Copy link
Collaborator

@dpanici dpanici left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Soon as all the tests pass, feel free to merge

@codecov
Copy link

codecov bot commented Aug 24, 2023

Codecov Report

Merging #635 (c73fe0f) into master (94eef4b) will increase coverage by 0.00%.
The diff coverage is n/a.

Additional details and impacted files
@@           Coverage Diff           @@
##           master     #635   +/-   ##
=======================================
  Coverage   94.26%   94.27%           
=======================================
  Files          78       78           
  Lines       18102    18102           
=======================================
+ Hits        17064    17065    +1     
+ Misses       1038     1037    -1     
Files Changed Coverage Δ
desc/compute/utils.py 95.60% <ø> (ø)

... and 1 file with indirect coverage changes

@f0uriest f0uriest merged commit ecc14fa into master Aug 24, 2023
@f0uriest f0uriest deleted the pk/hotfix-jax-cond branch August 24, 2023 23:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants