Description
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> m = pyhf.simplemodels.hepdata_like([10], [15], [5])
>>> pyhf.infer.mle.fit([12.5], m)
This crashes. The error message includes what may be a hint:
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
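To illustrate what the hint describes, here is a minimal standalone sketch in plain JAX. It is not pyhf code and does not claim to reproduce the code path that fails inside `pyhf.infer.mle.fit`. The function names and the `TABLE` array are hypothetical. It shows that `np.asarray` on a traced value fails under `jax.jit`, that the `jnp` equivalent works, and that a raw NumPy array can be indexed with a traced index once it is moved onto the device with `jax.device_put`.

```python
import numpy as np
import jax
import jax.numpy as jnp


def sum_squares_numpy(x):
    # np.asarray asks for a concrete NumPy array; under jax.jit, x is an
    # abstract Tracer, so JAX raises TracerArrayConversionError here.
    return np.sum(np.asarray(x) ** 2)


def sum_squares_jnp(x):
    # jnp functions accept Tracers, so this traces and compiles cleanly.
    return jnp.sum(x ** 2)


# Hypothetical raw NumPy lookup table, used only for illustration.
TABLE = np.array([1.0, 2.0, 4.0, 8.0])


def lookup(idx):
    # Following the hint: move the raw ndarray onto the device first so the
    # (traced) index is applied to a JAX array rather than a NumPy array.
    return jax.device_put(TABLE)[idx]


x = jnp.array([1.0, 2.0, 3.0])

try:
    jax.jit(sum_squares_numpy)(x)
    numpy_version_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_version_failed = True

jnp_result = jax.jit(sum_squares_jnp)(x)  # 1 + 4 + 9
lookup_result = jax.jit(lookup)(2)        # TABLE[2]
```

If pyhf's JAX backend hits this error, the likely suspect is a raw `np` call or a NumPy array indexed by a traced value somewhere along the fit path, though the sketch above does not identify where.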