py-econometrics/jaxonometrics

jaxonometrics: Econometrics in jax

A simple library providing performant implementations of standard econometrics routines in the JAX ecosystem:

  • JAX arrays everywhere
  • lineax for solving linear systems
  • jaxopt and optax for numerical optimization (Levenberg–Marquardt for nonlinear least-squares problems and SGD for larger problems)
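At its core, a linear regression fit is a single linear solve. The sketch below illustrates this in plain JAX. It uses `jnp.linalg.lstsq` in place of lineax (which the library uses), and the helper name `ols_coef` is hypothetical, not part of jaxonometrics. It recovers known coefficients from noiseless data and cross-checks them against the normal equations X'X b = X'y.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper (not the jaxonometrics API): OLS via a least-squares solve.
# jaxonometrics routes this through lineax; jnp.linalg.lstsq is swapped in here.
@jax.jit
def ols_coef(X, y):
    coef, _, _, _ = jnp.linalg.lstsq(X, y)  # returns (solution, residuals, rank, singular values)
    return coef

# Synthetic data with known coefficients and no noise
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (200, 3))
beta_true = jnp.array([1.0, -2.0, 0.5])
y = X @ beta_true

beta_lstsq = ols_coef(X, y)
# The same estimate from the normal equations X'X b = X'y
beta_normal = jnp.linalg.solve(X.T @ X, X.T @ y)
print(beta_lstsq, beta_normal)
```

The least-squares route avoids forming X'X, whose condition number is the square of that of X. This is the usual reason to prefer a dedicated solver over the normal equations.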

Features

  • Linear Regression with multiple solver backends (lineax, JAX, NumPy)
  • Fixed Effects Regression with JAX-accelerated alternating projections
  • GMM and IV Estimation
  • Causal Inference (IPW, AIPW, Entropy Balancing)
  • Maximum Likelihood Estimation (Logistic, Poisson)
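Alternating projections absorb several fixed effects at once. They repeatedly subtract group means, one fixed-effect dimension at a time, until the variables stop changing. By the Frisch–Waugh–Lovell theorem, OLS on the demeaned X and y then gives the slope coefficients. Below is a minimal sketch in plain JAX under stated assumptions: group codes are integers in 0..n_groups-1, and a fixed number of sweeps stands in for a convergence test. The names `group_mean` and `demean` are hypothetical and not the library's internals.

```python
import jax
import jax.numpy as jnp

def group_mean(v, ids, n_groups):
    # Mean of v within each observation's group, broadcast back to observations (v is 1-D or 2-D)
    sums = jax.ops.segment_sum(v, ids, num_segments=n_groups)
    counts = jax.ops.segment_sum(jnp.ones(ids.shape[0]), ids, num_segments=n_groups)
    means = sums / counts.reshape((-1,) + (1,) * (v.ndim - 1))
    return means[ids]

def demean(v, fe_ids, fe_sizes, n_iter=200):
    # One sweep subtracts group means for every FE dimension in turn.
    # Assumption: a fixed sweep count; a real solver would stop on a tolerance.
    def sweep(_, v):
        for ids, n_groups in zip(fe_ids, fe_sizes):
            v = v - group_mean(v, ids, n_groups)
        return v
    return jax.lax.fori_loop(0, n_iter, sweep, v)

# Synthetic balanced panel: 40 firms x 10 years, with firm and year effects in y
n_firms, n_years = 40, 10
firm_ids = jnp.repeat(jnp.arange(n_firms), n_years)
year_ids = jnp.tile(jnp.arange(n_years), n_firms)
key_x, key_a, key_g = jax.random.split(jax.random.PRNGKey(1), 3)
X = jax.random.normal(key_x, (n_firms * n_years, 2))
alpha = jax.random.normal(key_a, (n_firms,))  # firm effects
gamma = jax.random.normal(key_g, (n_years,))  # year effects
beta_true = jnp.array([0.7, -1.3])
y = X @ beta_true + alpha[firm_ids] + gamma[year_ids]

fe_ids, fe_sizes = [firm_ids, year_ids], [n_firms, n_years]
X_dm = demean(X, fe_ids, fe_sizes)
y_dm = demean(y, fe_ids, fe_sizes)
beta_hat = jnp.linalg.lstsq(X_dm, y_dm)[0]  # slopes with both FE dimensions absorbed
```

In a balanced panel like this one, a single sweep is already exact. Unbalanced panels need the repeated sweeps, which is why the loop exists.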

Fixed Effects

jaxonometrics supports high-performance fixed effects regression with multiple FE variables:

import jax.numpy as jnp
from jaxonometrics import LinearRegression

# Your data: replace data, target, firm_identifiers and year_identifiers with your own arrays
X = jnp.asarray(data)  # (n_obs, n_features)
y = jnp.asarray(target)  # (n_obs,)
firm_ids = jnp.asarray(firm_identifiers, dtype=jnp.int32)  # integer firm codes, (n_obs,)
year_ids = jnp.asarray(year_identifiers, dtype=jnp.int32)  # integer year codes, (n_obs,)

# Two-way fixed effects (firm and year), solved with the lineax backend
model = LinearRegression(solver="lineax")
model.fit(X, y, fe=[firm_ids, year_ids])
coefficients = model.params["coef"]
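On small problems, a fixed-effects fit can be checked against a least-squares dummy-variable (LSDV) regression. This regresses y on X plus firm and year indicators, and its slopes on X are the two-way FE coefficients. The sketch below uses only JAX, with synthetic unbalanced data standing in for `data`, `target`, and the ID arrays. It builds the dummies with `jax.nn.one_hot` and drops the first year dummy to avoid perfect collinearity with the firm dummies. It assumes, without confirmation from this README, that `model.params["coef"]` holds only the slopes on X. Under that assumption, those slopes should agree with `beta_lsdv` up to solver tolerance.

```python
import jax
import jax.numpy as jnp

# Synthetic unbalanced panel (assumption: stands in for your own data)
n_obs, n_firms, n_years = 600, 30, 12
k_year, k_x, k_a, k_g = jax.random.split(jax.random.PRNGKey(2), 4)
firm_ids = jnp.arange(n_obs, dtype=jnp.int32) % n_firms
year_ids = jax.random.randint(k_year, (n_obs,), 0, n_years, dtype=jnp.int32)
X = jax.random.normal(k_x, (n_obs, 2))
beta_true = jnp.array([1.5, -0.4])
y = (X @ beta_true
     + jax.random.normal(k_a, (n_firms,))[firm_ids]   # firm effects
     + jax.random.normal(k_g, (n_years,))[year_ids])  # year effects

# LSDV design: X, all firm dummies, and the year dummies except the first
D_firm = jax.nn.one_hot(firm_ids, n_firms)
D_year = jax.nn.one_hot(year_ids, n_years)[:, 1:]
Z = jnp.hstack([X, D_firm, D_year])

coef = jnp.linalg.lstsq(Z, y)[0]
beta_lsdv = coef[:2]  # slopes on X, i.e. the two-way FE coefficients
```

LSDV builds one column per group, so it stops scaling once there are many firms. Alternating projections avoid materializing those columns.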

Installation and Development

Install

uv pip install git+https://github.com/py-econometrics/jaxonometrics

Alternatively, clone the repository and install it in editable mode from the repository root:

uv pip install -e .

Testing

Run the full test suite:

pytest tests/ -v

Run only the fixed effects tests:

pytest tests/ -m fe -v

Run all tests except those marked slow:

pytest tests/ -m "not slow" -v

About

Econometrics on the GPU (and CPU) via JAX
