This repository documents my process of learning JAX, from the fundamental concepts to implementing deep learning models. It contains both my personal, easy-to-understand notes and the code for a series of progressively more complex projects.
This section contains my write-ups on the core concepts of JAX, based on the official documentation and my own experiments, along with a project that builds an MLP in JAX using the concepts learned.
- Chapter 1: The Basics of JAX
  - Covers the core transformations: `jit` (Just-In-Time Compilation), `vmap` (Vectorization), and `grad` (Automatic Differentiation).
- Chapter 2: Core Components
  - Explores essential building blocks: `jax.numpy`, explicit random number generation with `PRNGKey`, `jax.nn` for functional building blocks, and Pytrees for state management.
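As a rough illustration of how the pieces from both chapters fit together, here is a minimal sketch. The function `squared_error` and the `params` dictionary are hypothetical toy examples made up for this README, not code taken from the notes.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy function (not from the notes): squared error of a
# linear prediction for a single example.
def squared_error(w, x, y):
    return (jnp.dot(w, x) - y) ** 2


# grad: differentiate with respect to the first argument (w).
grad_fn = jax.grad(squared_error)

# vmap: map over the leading axis of x and y while sharing the same w.
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0, 0))

# jit: compile the batched gradient function with XLA.
fast_batched_grad = jax.jit(batched_grad)

# Explicit randomness: split the key instead of reusing it.
key = jax.random.PRNGKey(0)
key_w, key_x = jax.random.split(key)
w = jax.random.normal(key_w, (3,))
x = jax.random.normal(key_x, (4, 3))
y = jnp.ones(4)

# One gradient per example, so the result has shape (4, 3).
per_example_grads = fast_batched_grad(w, x, y)

# Pytrees: nested containers of arrays that JAX can traverse.
params = {"dense": {"w": w, "b": jnp.zeros(())}}
scaled = jax.tree_util.tree_map(lambda leaf: 0.5 * leaf, params)

# jax.nn: functional building blocks such as activations.
activations = jax.nn.relu(x @ w)
```

The transformations compose, which is why `jit(vmap(grad(f)))` works without any changes to `squared_error` itself.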
- Notebook
- A complete implementation of a neural network from scratch, built in a functional style using JAX. This project is a direct translation of an object-oriented NumPy implementation and puts into practice the core JAX concepts discussed in Chapter 1 and Chapter 2.
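To give a feel for what "functional style" means here, the following is a minimal sketch of an MLP whose parameters live in a pytree and are passed explicitly to pure functions. It is not the notebook's code: the helper names (`init_mlp`, `forward`, `mse_loss`, `sgd_step`), the layer sizes, the learning rate, and the toy data are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, layer_sizes):
    """Build a list of {"w", "b"} dicts, one per layer (a pytree)."""
    keys = jax.random.split(key, len(layer_sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        # He-style scaling, an assumption for this sketch.
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append({"w": w, "b": jnp.zeros(n_out)})
    return params


def forward(params, x):
    """Pure forward pass: ReLU on hidden layers, linear output."""
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer["w"] + layer["b"])
    last = params[-1]
    return x @ last["w"] + last["b"]


def mse_loss(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)


@jax.jit
def sgd_step(params, x, y, lr):
    """Return new parameters; nothing is mutated in place."""
    grads = jax.grad(mse_loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


# Toy regression data (hypothetical): the target is the sum of the inputs.
key = jax.random.PRNGKey(0)
key_init, key_x = jax.random.split(key)
params = init_mlp(key_init, [2, 8, 1])
x = jax.random.normal(key_x, (16, 2))
y = jnp.sum(x, axis=1, keepdims=True)

loss_before = mse_loss(params, x, y)
for _ in range(100):
    params = sgd_step(params, x, y, 0.05)
loss_after = mse_loss(params, x, y)
```

In contrast to an object-oriented version, there is no model object holding state: `sgd_step` takes the current parameters and returns updated ones, which is what lets `jit` and `grad` be applied directly.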
- Clone the repository:
  git clone https://github.com/Awesome075/jax.git
  cd jax
- Create a virtual environment:
  python -m venv venv
- Activate the environment:
  - On Windows:
    venv\Scripts\activate
  - On Linux/macOS:
    source venv/bin/activate
- Install the required dependencies:
  pip install -r requirements.txt
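After installation, a quick check like the one below can confirm that JAX imports correctly and show which devices it sees. This is only a sketch and assumes that `requirements.txt` installs `jax`; the device list depends on your machine.

```python
import jax
import jax.numpy as jnp

# Report the installed version and the devices JAX can use (CPU, GPU, or TPU).
print("JAX version:", jax.__version__)
print("Devices:", jax.devices())

# Tiny computation to confirm that arrays work: 0 + 1 + 2 + 3.
total = jnp.arange(4).sum()
print("Sanity check:", int(total))
```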
The project is organized into notebooks and Markdown files. You can explore the `notes` directory to follow my learning path or run the project notebooks to see the code in action.
To run the Jupyter notebooks, first set up the kernel:
python -m ipykernel install --user --name=jax-env
Then, you can run the notebook using:
jupyter notebook notes/"Phase 1"/MLP_in_JAX.ipynb
Inside the notebook, make sure to select the `jax-env` kernel.
Contributions are welcome! If you have any suggestions or find any bugs, please open an issue or submit a pull request.
This project is licensed under the MIT License - see the LICENSE file for details.