A simple library providing performant implementations of standard econometrics routines in the JAX ecosystem. It builds on:

- `jax` arrays everywhere
- `lineax` for solving linear systems
- `jaxopt` and `optax` for numerical optimization (Levenberg–Marquardt for NNLS-type problems and SGD for larger problems)
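Under the hood, a linear regression comes down to solving a linear system. The sketch below illustrates that step in plain `jax.numpy`/`jax.scipy` and is not jaxonometrics' own code (the library lists `lineax` for this job). It solves the normal equations `(X'X) b = X'y` with a Cholesky factorization and checks the result against `jnp.linalg.lstsq`. The function names `ols_normal_equations` and `ols_lstsq` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)  # float64 so the two solvers agree closely


def ols_normal_equations(X, y):
    """Solve (X'X) b = X'y via a Cholesky factorization of X'X."""
    c_and_lower = cho_factor(X.T @ X)
    return cho_solve(c_and_lower, X.T @ y)


def ols_lstsq(X, y):
    """Minimize ||y - X b||^2 directly (SVD-based least squares)."""
    coef, _, _, _ = jnp.linalg.lstsq(X, y)
    return coef


# Simulated data with known coefficients
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(k1, (500, 3))
beta_true = jnp.array([1.0, -2.0, 0.5])
y = X @ beta_true + 0.1 * jax.random.normal(k2, (500,))

b_chol = ols_normal_equations(X, y)
b_lstsq = ols_lstsq(X, y)
```

Cholesky on the normal equations is cheap but squares the condition number of `X`. The SVD-based `lstsq` costs more but is more robust when regressors are nearly collinear.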

Main features:

- Linear Regression with multiple solver backends (lineax, JAX, numpy)
- Fixed Effects Regression with JAX-accelerated alternating projections
- GMM and IV Estimation
- Causal Inference (IPW, AIPW, Entropy Balancing)
- Maximum Likelihood Estimation (Logistic, Poisson)
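The README does not show the GMM/IV interface. The following is therefore only a textbook sketch of two-stage least squares (2SLS), which coincides with linear GMM using the weight matrix `(Z'Z)^{-1}`, written in plain `jax.numpy`. `two_stage_least_squares` is a hypothetical name, not the library's API. In the simulated data the regressor `x` is endogenous because it loads on the structural error `u`, so OLS overstates the slope while 2SLS recovers the true value of 2.

```python
import jax
import jax.numpy as jnp


def two_stage_least_squares(y, X, Z):
    """2SLS point estimates: regress X on Z, then y on the fitted values X_hat = P_Z X."""
    pi, _, _, _ = jnp.linalg.lstsq(Z, X)  # first stage
    X_hat = Z @ pi
    beta, _, _, _ = jnp.linalg.lstsq(X_hat, y)  # second stage
    return beta


# Simulated IV setting: z is a valid instrument, x is correlated with the error u
n = 5000
kz, ku, kv = jax.random.split(jax.random.PRNGKey(1), 3)
z = jax.random.normal(kz, (n,))
u = jax.random.normal(ku, (n,))
v = jax.random.normal(kv, (n,))
x = 0.8 * z + 0.6 * u + v
y = 1.0 + 2.0 * x + u

X = jnp.column_stack([jnp.ones(n), x])  # intercept + endogenous regressor
Z = jnp.column_stack([jnp.ones(n), z])  # intercept + instrument

beta_iv = two_stage_least_squares(y, X, Z)
beta_ols, _, _, _ = jnp.linalg.lstsq(X, y)  # biased: slope near 2 + 0.6 / 2 = 2.3
```

This returns point estimates only. Standard errors computed naively from the second-stage regression are not valid, because its residuals use `X_hat` instead of `X`.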
jaxonometrics supports high-performance fixed effects regression with multiple FE variables:
```python
import jax.numpy as jnp
from jaxonometrics import LinearRegression

# Your data: `data`, `target`, `firm_identifiers` and `year_identifiers` stand in for your own arrays
X = jnp.asarray(data)  # (n_obs, n_features)
y = jnp.asarray(target)  # (n_obs,)
firm_ids = jnp.asarray(firm_identifiers, dtype=jnp.int32)  # integer group codes
year_ids = jnp.asarray(year_identifiers, dtype=jnp.int32)

# Two-way (firm and year) fixed effects
model = LinearRegression(solver="lineax")
model.fit(X, y, fe=[firm_ids, year_ids])
coefficients = model.params["coef"]
```
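A common way to absorb several fixed effects is the method of alternating projections: repeatedly subtract firm means, then year means, from `y` and every column of `X` until the data stop changing, then run OLS on the demeaned data (Frisch–Waugh–Lovell). The sketch below illustrates the idea under simple assumptions (group codes run from 0 to G-1, and a fixed number of sweeps replaces a convergence tolerance). It is not jaxonometrics' internal implementation, and `demean_by_group`, `alternating_projections` and `fe_ols` are hypothetical names. It checks the result against OLS with explicit firm and year dummies.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 so the fixed point is reached accurately


def demean_by_group(v, ids, n_groups):
    """One projection step: subtract group means from each column of v."""
    sums = jax.ops.segment_sum(v, ids, num_segments=n_groups)
    counts = jax.ops.segment_sum(jnp.ones(v.shape[0]), ids, num_segments=n_groups)
    return v - (sums / counts[:, None])[ids]


def alternating_projections(v, fe_ids, n_groups, n_iter=100):
    """Sweep over all FE dimensions n_iter times (fixed count, no tolerance check)."""
    def sweep(_, v):
        for ids, g in zip(fe_ids, n_groups):
            v = demean_by_group(v, ids, g)
        return v
    return jax.lax.fori_loop(0, n_iter, sweep, v)


def fe_ols(X, y, fe_ids, n_groups, n_iter=100):
    """Within estimator: demean [y, X] jointly, then OLS without an intercept."""
    Zd = alternating_projections(jnp.column_stack([y, X]), fe_ids, n_groups, n_iter)
    coef, _, _, _ = jnp.linalg.lstsq(Zd[:, 1:], Zd[:, 0])
    return coef


# Unbalanced synthetic panel: regressors are correlated with the firm effect
n, n_firms, n_years = 2000, 50, 10
k_f, k_y, k_fe, k_ye, k_x, k_e = jax.random.split(jax.random.PRNGKey(0), 6)
firm = jax.random.randint(k_f, (n,), 0, n_firms)
year = jax.random.randint(k_y, (n,), 0, n_years)
firm_effect = jax.random.normal(k_fe, (n_firms,))
year_effect = jax.random.normal(k_ye, (n_years,))
X = jax.random.normal(k_x, (n, 2)) + firm_effect[firm][:, None]
beta_true = jnp.array([1.5, -0.7])
y = X @ beta_true + firm_effect[firm] + year_effect[year] + 0.5 * jax.random.normal(k_e, (n,))

coef_fe = fe_ols(X, y, [firm, year], [n_firms, n_years])

# Reference: least squares with firm dummies and year dummies (first year dropped)
D = jnp.column_stack([X, jax.nn.one_hot(firm, n_firms), jax.nn.one_hot(year, n_years)[:, 1:]])
coef_dummy = jnp.linalg.lstsq(D, y)[0][:2]
```

Practical implementations usually stop once the change between sweeps falls below a tolerance (for example with `jax.lax.while_loop`). The fixed count here keeps the sketch short.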
Install from GitHub:

```bash
uv pip install git+https://github.com/py-econometrics/jaxonometrics
```

Or clone the repository and install it in editable mode.
Run the full test suite:

```bash
pytest tests/ -v
```

Run only the fixed effects tests:

```bash
pytest tests/ -m fe -v
```

Run tests excluding slow ones:

```bash
pytest tests/ -m "not slow" -v
```
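The `-m fe` and `-m "not slow"` selections rely on pytest markers. As a hypothetical sketch (the repository's actual test files and marker descriptions may differ), tests could be tagged and the markers registered in a `conftest.py` along these lines. The test names are placeholders.

```python
import pytest


def pytest_configure(config):
    # Register the markers used with `-m` (conftest.py hook)
    config.addinivalue_line("markers", "fe: fixed effects regression tests")
    config.addinivalue_line("markers", "slow: long-running tests")


@pytest.mark.fe
def test_two_way_fe_placeholder():
    # Selected by `pytest -m fe`
    assert True


@pytest.mark.fe
@pytest.mark.slow
def test_large_panel_placeholder():
    # Selected by `-m fe`, skipped by `-m "not slow"`
    assert True
```

With the markers registered, `pytest --strict-markers` rejects misspelled marker names instead of silently accepting them.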