Skip to content

Commit e715881

Browse files
authored
Fix for jax>=0.4.7 (#1595)
* Fix for jax 0.4.11 * Require jax, jaxlib version >= 0.4.7
1 parent f981b29 commit e715881

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

numpyro/ops/provenance.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import jax
5-
from jax._src.pjit import pjit_p
65
from jax.api_util import flatten_fun, shaped_abstractify
76
import jax.core as core
7+
from jax.experimental.pjit import pjit_p
88
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
99
from jax.interpreters.pxla import xla_pmap_p
10-
from jax.interpreters.xla import xla_call_p
1110
import jax.linear_util as lu
1211
import jax.numpy as jnp
1312

@@ -102,7 +101,6 @@ def track_deps_call_rule(eqn, provenance_inputs):
102101

103102

104103
track_deps_rules[core.call_p] = track_deps_call_rule
105-
track_deps_rules[xla_call_p] = track_deps_call_rule
106104
track_deps_rules[xla_pmap_p] = track_deps_call_rule
107105

108106

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from setuptools import find_packages, setup
1010

1111
PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
12-
_jax_version_constraints = ">=0.4"
13-
_jaxlib_version_constraints = ">=0.4"
12+
_jax_version_constraints = ">=0.4.7"
13+
_jaxlib_version_constraints = ">=0.4.7"
1414

1515
# Find version
1616
for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")):

0 commit comments

Comments
 (0)