
Raise better error message for pyhf.exceptions.InvalidPdfData for JAX backend #1422

Open
@kratsg

Description

```pycon
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> m = pyhf.simplemodels.hepdata_like([10], [15], [5])
>>> pyhf.infer.mle.fit([12.5], m)
```
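Here is a hedged reading of why the call is invalid in the first place. This model appears to attach one uncorrelated (`shapesys`) background uncertainty per bin. If so, the data passed to `fit` must hold the main-bin observations followed by one auxiliary entry per constrained bin, and pyhf exposes those entries as `m.config.auxdata`. The sketch below also assumes that the auxiliary value is the constraint scale $\tau = (b/\sigma_b)^2$, which is an assumption about pyhf's internals. Under those assumptions, the numbers above give:

```math
\tau = \left(\frac{b}{\sigma_b}\right)^2 = \left(\frac{15}{5}\right)^2 = 9,
\qquad
\text{data} = \left(n_{\text{obs}},\ \tau\right) = \left(12.5,\ 9\right)
```

So `[12.5]` supplies 1 entry where 2 are expected. That is presumably the condition `pyhf.exceptions.InvalidPdfData` is meant to report, and a call such as `pyhf.infer.mle.fit([12.5] + m.config.auxdata, m)` should not trigger it.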

The `fit` call above crashes like so:

[Screenshot of the traceback, taken 2021-04-29 at 1:06 PM]

The error does come with a possible hint:

> This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a `numpy.ndarray` object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw `numpy.ndarray` while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
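The following is a minimal sketch of a check that could give a clearer message under JAX. Because the screenshot is not reproduced here, it rests on an assumption: that the crash comes from converting traced data *values* to a concrete array while the error message is being built. Inside `jax.jit`, `data.shape` is a concrete Python tuple even though `data` is a tracer, so a length check and its message can be built from the shape alone. The names `InvalidPdfData`, `N_EXPECTED`, `nll_like`, and `bad_message` are hypothetical and used only for illustration. They are not pyhf's API.

```python
import jax
import jax.numpy as jnp
import numpy as np


class InvalidPdfData(Exception):
    """Hypothetical stand-in for pyhf.exceptions.InvalidPdfData."""


# Hypothetical expected length: 1 main bin + 1 auxiliary entry.
N_EXPECTED = 2


def check_data(data):
    # data.shape is a concrete tuple even when `data` is a JAX tracer,
    # so both the comparison and the f-string are safe under jax.jit.
    if data.shape[-1] != N_EXPECTED:
        raise InvalidPdfData(
            f"data has length {data.shape[-1]} but {N_EXPECTED} was expected"
        )
    return data


@jax.jit
def nll_like(data):
    # Toy objective standing in for a likelihood evaluation.
    data = check_data(data)
    return jnp.sum(data**2)


@jax.jit
def bad_message(data):
    # Converting traced values to a NumPy array (e.g. to print them in an
    # error message) fails at trace time with TracerArrayConversionError.
    return np.asarray(data).tolist()
```

With this sketch, `nll_like(jnp.array([12.5]))` raises `InvalidPdfData` with a readable message while the function is being traced. By contrast, `bad_message` fails with `jax.errors.TracerArrayConversionError`, the JAX error whose message contains the hint quoted above.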

Metadata


Labels

help wanted (Extra attention is needed / contributions welcome)
