File tree Expand file tree Collapse file tree 2 files changed +3
-5
lines changed Expand file tree Collapse file tree 2 files changed +3
-5
lines changed Original file line number Diff line number Diff line change 2
2
# SPDX-License-Identifier: Apache-2.0
3
3
4
4
import jax
5
- from jax ._src .pjit import pjit_p
6
5
from jax .api_util import flatten_fun , shaped_abstractify
7
6
import jax .core as core
7
+ from jax .experimental .pjit import pjit_p
8
8
from jax .interpreters .partial_eval import trace_to_jaxpr_dynamic
9
9
from jax .interpreters .pxla import xla_pmap_p
10
- from jax .interpreters .xla import xla_call_p
11
10
import jax .linear_util as lu
12
11
import jax .numpy as jnp
13
12
@@ -102,7 +101,6 @@ def track_deps_call_rule(eqn, provenance_inputs):
102
101
103
102
104
103
track_deps_rules [core .call_p ] = track_deps_call_rule
105
- track_deps_rules [xla_call_p ] = track_deps_call_rule
106
104
track_deps_rules [xla_pmap_p ] = track_deps_call_rule
107
105
108
106
Original file line number Diff line number Diff line change 9
9
from setuptools import find_packages , setup
10
10
11
11
PROJECT_PATH = os .path .dirname (os .path .abspath (__file__ ))
12
- _jax_version_constraints = ">=0.4"
13
- _jaxlib_version_constraints = ">=0.4"
12
+ _jax_version_constraints = ">=0.4.7 "
13
+ _jaxlib_version_constraints = ">=0.4.7 "
14
14
15
15
# Find version
16
16
for line in open (os .path .join (PROJECT_PATH , "numpyro" , "version.py" )):
You can’t perform that action at this time.
0 commit comments