This repository is an implementation of the Diffusion Policy for Offline RL algorithm using JAX/Flax. It also marks my first attempt at constructing a relatively complex reinforcement learning system using technologies beyond PyTorch.
Install the required packages with the following command (adjust based on your environment):

```
pip install d4rl gym numpy jax flax optax chex distrax wandb
```
Execute the following commands:

```
git clone https://github.com/dibyaghosh/jaxrl_m.git
cd jaxrl_m
pip install -e .
```
The `-e` flag installs the package in editable (development) mode, so the project's dependency on jaxrl_m resolves to your local checkout.
From the project's root directory, run:

```
python run_<algo_name>.py
```
- `hyper_<name>.py`: default hyperparameters and tuning configurations.
- `util_<name>.py`: utility functions for data loading, models, and other helper operations.
- `model_<name>.py`: network architecture definitions.
- `algo_<name>.py`: the core logic of the RL agent, including creation, updates, and sampling.
- `run_<name>.py`: the entry point for running the program.
- `xxx_test.py`: test files.
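To show how these modules fit together, here is a minimal sketch of what a `run_<name>.py` entry point might look like. All function names and return values below are illustrative stand-ins, not the project's actual API:

```python
# Hypothetical sketch of run_<name>.py wiring the other modules together.
# Each stub stands in for code that would live in the named file.

def load_hyperparams():            # hyper_<name>.py
    return {"lr": 3e-4, "batch_size": 256, "steps": 1_000}

def load_dataset(env_name):        # util_<name>.py
    return {"env": env_name, "transitions": []}

def build_model(hp):               # model_<name>.py
    return {"params": {}, "lr": hp["lr"]}

def create_agent(model, dataset):  # algo_<name>.py: creation, updates, sampling
    return {"model": model, "dataset": dataset, "updates": 0}

def update(agent):                 # algo_<name>.py: one training step
    agent["updates"] += 1
    return agent

def main(env_name="halfcheetah-medium-v2"):
    hp = load_hyperparams()
    agent = create_agent(build_model(hp), load_dataset(env_name))
    for _ in range(hp["steps"]):
        agent = update(agent)
    return agent

agent = main()
```

The real agent would hold Flax parameters and optimizer state rather than plain dicts, but the division of responsibilities across files is the same.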
The training process and results can be monitored on the Weights & Biases platform.
Benchmark results on an RTX 4060 gaming laptop:
- Training speed: improved from ~38 to ~650 iterations per second, a significant speedup.
- GPU utilization: rose from ~20% to ~45%, a modest increase.
- GPU memory usage: grew from ~15% to ~70%, reflecting fuller use of GPU resources.
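Much of this speedup typically comes from `jax.jit` compiling the entire update step into a single fused XLA program. The toy MSE loss below is only a stand-in for illustration, not the project's actual diffusion-policy update:

```python
import time
import jax
import jax.numpy as jnp

# Toy loss: mean-squared error of a linear model (a stand-in for the
# real update logic that would live in algo_<name>.py).
def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

grad_fn = jax.grad(loss_fn)

@jax.jit
def update(w, x, y, lr=1e-3):
    # One SGD step; jax.jit fuses grad + update into one compiled kernel.
    return w - lr * grad_fn(w, x, y)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (256, 32))
y = jax.random.normal(key, (256,))
w = jnp.zeros(32)

w = update(w, x, y)          # first call triggers compilation (slow)
w.block_until_ready()
start = time.perf_counter()
for _ in range(100):
    w = update(w, x, y)      # later calls reuse the compiled kernel (fast)
w.block_until_ready()
print(f"100 jitted updates took {time.perf_counter() - start:.4f}s")
```

The first call pays a one-time compilation cost; subsequent calls dispatch the cached kernel, which is why iteration throughput jumps once the loop is jit-compiled.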
- Thanks to JAX, Flax, Optax, Distrax, and other high-quality deep learning libraries for their elegant code and comprehensive documentation.
- Appreciation goes to jaxrl, jaxrl2, and jaxrl_m for their outstanding contributions to applying JAX/Flax in reinforcement learning.
- Special thanks to the original author of Diffusion Policy for Offline RL for providing a robust algorithm that remained highly reproducible even after the framework migration.
If you encounter any issues while using this project, please feel free to submit an issue or a pull request to help improve it. Wishing you success in your reinforcement learning research!