Skip to content

Commit 115c4d3

Browse files
authored
Support custom prng key (#1642)
* support custom prng key * run black * test custom prng in CI * fix some deprecation warnings
1 parent ca96eca commit 115c4d3

File tree

23 files changed

+76
-54
lines changed

23 files changed

+76
-54
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ jobs:
112112
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap"
113113
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain"
114114
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain"
115+
- name: Test custom prng
116+
run: |
117+
JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py
115118
116119
117120
examples:

numpyro/contrib/einstein/steinvi.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import operator
1111

1212
from jax import grad, jacfwd, numpy as jnp, random, vmap
13-
from jax.random import KeyArray
1413
from jax.tree_util import tree_map
1514

1615
from numpyro import handlers
@@ -370,10 +369,10 @@ def _update_force(attr_force, rep_force, jac):
370369
)
371370
return jnp.linalg.norm(particle_grads), res_grads
372371

373-
def init(self, rng_key: KeyArray, *args, **kwargs):
372+
def init(self, rng_key, *args, **kwargs):
374373
"""Register random variable transformations, constraints and determine initialize positions of the particles.
375374
376-
:param KeyArray rng_key: Random number generator seed.
375+
:param rng_key: Random number generator seed.
377376
:param args: Arguments to the model / guide.
378377
:param kwargs: Keyword arguments to the model / guide.
379378
:return: initial :data:`SteinVIState`

numpyro/contrib/tfp/mcmc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from numpyro.infer import init_to_uniform
1515
from numpyro.infer.mcmc import MCMCKernel
1616
from numpyro.infer.util import initialize_model
17-
from numpyro.util import identity
17+
from numpyro.util import identity, is_prng_key
1818

1919
TFPKernelState = namedtuple("TFPKernelState", ["z", "kernel_results", "rng_key"])
2020

@@ -174,7 +174,7 @@ def init(
174174
self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}
175175
):
176176
# non-vectorized
177-
if rng_key.ndim == 1:
177+
if is_prng_key(rng_key):
178178
rng_key, rng_key_init_model = random.split(rng_key)
179179
# vectorized
180180
else:
@@ -190,7 +190,7 @@ def init(
190190
" `target_log_prob_fn`."
191191
)
192192

193-
if rng_key.ndim == 1:
193+
if is_prng_key(rng_key):
194194
init_state = self._init_fn(init_params, rng_key)
195195
else:
196196
# XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some

numpyro/distributions/conjugate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
ZeroInflatedDistribution,
1515
)
1616
from numpyro.distributions.distribution import Distribution
17-
from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample
17+
from numpyro.distributions.util import promote_shapes, validate_sample
18+
from numpyro.util import is_prng_key
1819

1920

2021
def _log_beta_1(alpha, value):

numpyro/distributions/continuous.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,14 @@
6464
betaincinv,
6565
cholesky_of_inverse,
6666
gammaincinv,
67-
is_prng_key,
6867
lazy_property,
6968
matrix_to_tril_vec,
7069
promote_shapes,
7170
signed_stick_breaking_tril,
7271
validate_sample,
7372
vec_to_tril_matrix,
7473
)
74+
from numpyro.util import is_prng_key
7575

7676

7777
class AsymmetricLaplace(Distribution):

numpyro/distributions/copula.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,8 @@
66
import numpyro.distributions.constraints as constraints
77
from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal
88
from numpyro.distributions.distribution import Distribution
9-
from numpyro.distributions.util import (
10-
clamp_probs,
11-
is_prng_key,
12-
lazy_property,
13-
validate_sample,
14-
)
9+
from numpyro.distributions.util import clamp_probs, lazy_property, validate_sample
10+
from numpyro.util import is_prng_key
1511

1612

1713
class GaussianCopula(Distribution):

numpyro/distributions/directional.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616
from numpyro.distributions import constraints
1717
from numpyro.distributions.distribution import Distribution
1818
from numpyro.distributions.util import (
19-
is_prng_key,
2019
lazy_property,
2120
promote_shapes,
2221
safe_normalize,
2322
validate_sample,
2423
von_mises_centered,
2524
)
26-
from numpyro.util import while_loop
25+
from numpyro.util import is_prng_key, while_loop
2726

2827

2928
def _numel(shape):

numpyro/distributions/discrete.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,12 @@
4141
binomial,
4242
categorical,
4343
clamp_probs,
44-
is_prng_key,
4544
lazy_property,
4645
multinomial,
4746
promote_shapes,
4847
validate_sample,
4948
)
50-
from numpyro.util import not_jax_tracer
49+
from numpyro.util import is_prng_key, not_jax_tracer
5150

5251

5352
def _to_probs_bernoulli(logits):

numpyro/distributions/mixtures.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
from numpyro.distributions import Distribution, constraints
99
from numpyro.distributions.discrete import CategoricalLogits, CategoricalProbs
10-
from numpyro.distributions.util import is_prng_key, validate_sample
10+
from numpyro.distributions.util import validate_sample
11+
from numpyro.util import is_prng_key
1112

1213

1314
def Mixture(mixing_distribution, component_distributions, *, validate_args=None):

numpyro/distributions/truncated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
from numpyro.distributions.distribution import Distribution
2020
from numpyro.distributions.util import (
2121
clamp_probs,
22-
is_prng_key,
2322
lazy_property,
2423
promote_shapes,
2524
validate_sample,
2625
)
26+
from numpyro.util import is_prng_key
2727

2828

2929
class LeftTruncatedDistribution(Distribution):

0 commit comments

Comments
 (0)