Skip to content

Commit ab77747

Browse files
authored
Handle renaming of pjit_p to jit_p (#2052)
* Handle renaming of `pjit_p` to `jit_p` * Replace tensorflow_probability with tfp-nightly * Replace tensorflow_probability with tfp-nightly * Update jaxns upper bound * Bump jaxns version * Use different jaxns version for Python 3.9
1 parent 0d4f40c commit ab77747

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

docs/requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ funsor
44
ipython
55
jax
66
jaxlib
7-
jaxns==2.6.3
7+
jaxns==2.6.9;python_version>="3.10"
8+
jaxns==2.6.3;python_version<"3.10"
89
Jinja2
910
matplotlib
1011
multipledispatch
@@ -18,5 +19,5 @@ readthedocs-sphinx-search>=0.3.2
1819
sphinx>=5
1920
sphinx-gallery
2021
sphinx_rtd_theme
21-
tensorflow_probability
22+
tfp-nightly
2223
tqdm

numpyro/ops/provenance.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33

44
import jax
55
from jax.api_util import flatten_fun, shaped_abstractify
6-
from jax.experimental.pjit import pjit_p
76

7+
try:
8+
from jax.experimental.pjit import pjit_p
9+
except ImportError:
10+
from jax.extend.core.primitives import jit_p as pjit_p
811
try:
912
import jax.extend.linear_util as lu
1013
except ImportError:

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,14 @@
6464
"flax",
6565
"funsor>=0.4.1",
6666
"graphviz",
67-
"jaxns>=2.6.3,<=2.6.8",
67+
"jaxns>=2.6.3,<=2.6.9",
6868
"matplotlib",
6969
"optax>=0.0.6",
7070
"pylab-sdk", # jaxns dependency
7171
"pytest-cov",
7272
"pyyaml", # flax dependency
7373
"requests", # pylab dependency
74-
"tensorflow_probability>=0.18.0",
74+
"tfp-nightly",
7575
],
7676
"examples": [
7777
"arviz",

0 commit comments

Comments
 (0)