qw3rtman/robust-world-model-planning

Closing the Train-Test Gap in World Models for Gradient-Based Planning

Arjun Parthasarathy*, Nimit Kalra*, Rohun Agrawal*,
Yann LeCun, Oumayma Bounou, Pavel Izmailov, Micah Goldblum

Checkpoints

Pretrained world model checkpoints are provided by DINO-WM and can be downloaded here under checkpoints.

Our Method

Online/Adversarial World Modeling Checkpoints. Below, we provide checkpoints obtained after applying our methods to the pretrained DINO-WM checkpoints. These correspond to the results in Table 1. We recommend using gdown to download these to your machine.

| Method      | PushT                  | PointMaze                  | Wall                  |
|-------------|------------------------|----------------------------|-----------------------|
| Online      | pusht.online.6000      | pointmaze.online.100       | wall.online.full      |
| Adversarial | pusht.adversarial.full | pointmaze.adversarial.full | wall.adversarial.full |
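As a sketch of how you might stage these downloads, the loop below creates the ./outputs/<env>/<method> layout that the planning instructions later in this README expect. The model_latest.pth filename is an assumption borrowed from the training section, and the Google Drive links must be filled in from the table above:

```shell
# Create the ./outputs/<env>/<method> layout used by plan.py, then place
# each downloaded checkpoint inside its checkpoints/ subfolder.
for env in pusht point_maze wall_single; do
  for method in online adversarial; do
    mkdir -p "outputs/$env/$method/checkpoints"
    # Substitute the matching Google Drive link from the table above:
    # gdown --fuzzy "<drive-link>" \
    #   -O "outputs/$env/$method/checkpoints/model_latest.pth"
  done
done
```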

Installation

Our code is adapted from DINO-WM. Please refer to their repo for any additional installation and setup instructions. See here for Modal-specific instructions.

First clone the repo and create a Python environment for dependencies.

git clone https://github.com/qw3rtman/robust-world-model-planning.git
cd robust-world-model-planning
conda env create -f environment.yaml
conda activate robust_wm

Then, install MuJoCo.

mkdir -p ~/.mujoco
wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -P ~/.mujoco/
cd ~/.mujoco
tar -xzvf mujoco210-linux-x86_64.tar.gz
# MuJoCo path. Replace `<username>` with your actual username if necessary.
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/<username>/.mujoco/mujoco210/bin

# NVIDIA Library Path (if using NVIDIA GPUs)
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia
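A quick sanity check (a sketch, assuming the default install location above) that the MuJoCo directory actually made it onto LD_LIBRARY_PATH:

```shell
# Append the MuJoCo library directory and verify it is visible on
# LD_LIBRARY_PATH before launching any simulations.
MUJOCO_BIN="$HOME/.mujoco/mujoco210/bin"
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$MUJOCO_BIN"
case ":$LD_LIBRARY_PATH:" in
  *":$MUJOCO_BIN:"*) echo "MuJoCo path is on LD_LIBRARY_PATH" ;;
  *)                 echo "MuJoCo path missing from LD_LIBRARY_PATH" ;;
esac
```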

Notes:

  • For GPU-accelerated simulations, ensure the NVIDIA drivers are correctly installed.
  • If you encounter issues, confirm that the paths in your LD_LIBRARY_PATH are correct.
  • If problems persist, refer to these GitHub issue pages for potential solutions: openai/mujoco-py#773, ethz-asl/reinmav-gym#35.

Datasets

We use the training data collected by Zhou et al. for DINO-WM, available here. Once the datasets are downloaded, unzip them.

Set an environment variable pointing to your dataset folder:

# Replace /path/to/data with the actual path to your dataset folder.
export DATASET_DIR=/path/to/data

Set up the dataset folder with the following structure:

data
├── point_maze
├── pusht_noise
└── wall_single
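Before launching training, it can be worth confirming the layout above is in place. The helper below is a hypothetical convenience of ours, not part of the repo:

```python
import os

# Hypothetical helper (not part of the repo): confirm that DATASET_DIR
# contains the three environment subfolders before launching training.
EXPECTED_DATASETS = ("point_maze", "pusht_noise", "wall_single")

def missing_datasets(dataset_dir):
    """Return the expected subfolders that are absent from dataset_dir."""
    return [name for name in EXPECTED_DATASETS
            if not os.path.isdir(os.path.join(dataset_dir, name))]

absent = missing_datasets(os.environ.get("DATASET_DIR", "data"))
if absent:
    print("missing:", ", ".join(absent))
else:
    print("dataset layout looks good")
```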

Training Robust World Models

To finetune a base world model with either Online World Modeling or Adversarial World Modeling, run train.py with the appropriate overrides for the environment and method:

# PushT
python train.py --config-name train.yaml env=pusht ckpt_path=./outputs/pusht/ method=online
python train.py --config-name train.yaml env=pusht ckpt_path=./outputs/pusht/ method=adversarial

# PointMaze
python train.py --config-name train.yaml env=point_maze ckpt_path=./outputs/point_maze/ method=online
python train.py --config-name train.yaml env=point_maze ckpt_path=./outputs/point_maze/ method=adversarial

# Wall
python train.py --config-name train.yaml env=wall ckpt_path=./outputs/wall_single/ num_hist=1 method=online
python train.py --config-name train.yaml env=wall ckpt_path=./outputs/wall_single/ num_hist=1 method=adversarial

Note that the base checkpoint used will be <ckpt_path>/checkpoints/model_latest.pth. During finetuning, checkpoints will be saved to <ckpt_path>/<method>/<date>/<time>/checkpoints.
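Since each finetuning run lands in a timestamped folder, a small sketch like the following can locate the most recent run to pass to plan.py. The helper is ours, not the repo's, and it assumes the <date>/<time> folder names sort lexicographically (e.g. 2025-01-30/12-00-00):

```python
from pathlib import Path

def latest_run_dir(ckpt_path, method):
    """Return the newest <ckpt_path>/<method>/<date>/<time> run directory,
    relying on the date/time folder names sorting lexicographically."""
    runs = sorted(p for p in Path(ckpt_path, method).glob("*/*")
                  if (p / "checkpoints").is_dir())
    if not runs:
        raise FileNotFoundError(f"no finetuning runs under {ckpt_path}/{method}")
    return runs[-1]
```

The returned path can then be supplied as the ckpt_path override when planning.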

Planning with Robust World Models

To plan with a finetuned world model, run plan.py with the appropriate config file for the environment. Set ckpt_path to the parent directory of the checkpoints folder containing the checkpoint you want to use.

# PushT
python plan.py --config-name plan_pusht.yaml ckpt_path=./outputs/pusht/<method>/<date>/<time>/

# PointMaze
python plan.py --config-name plan_point_maze.yaml ckpt_path=./outputs/point_maze/<method>/<date>/<time>/

# Wall
python plan.py --config-name plan_wall.yaml ckpt_path=./outputs/wall_single/<method>/<date>/<time>/

Planning results and visualizations will be saved to plan_outputs/<env>/<current_time>.

To use our paper's checkpoints instead, set ckpt_path=./outputs/<env>/<method>.
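The env-to-config mapping above can be captured in a small convenience sketch (ours, not part of the repo; the env keys follow the train.py overrides, where env=wall writes to outputs/wall_single):

```python
# Hypothetical helper: compose the plan.py invocation for an environment
# and a checkpoint directory (finetuned run or paper checkpoint).
PLAN_CONFIGS = {
    "pusht": "plan_pusht.yaml",
    "point_maze": "plan_point_maze.yaml",
    "wall": "plan_wall.yaml",
}

def plan_command(env, ckpt_path):
    """Build the shell command for planning with a given checkpoint."""
    return (f"python plan.py --config-name {PLAN_CONFIGS[env]} "
            f"ckpt_path={ckpt_path}")

print(plan_command("pusht", "./outputs/pusht/online"))
```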

Citation

@article{parthasarathy2025closing,
    title   = {Closing the Train–Test Gap in World Models for Gradient-Based Planning},
    author  = {Parthasarathy, Arjun and Kalra, Nimit and Agrawal, Rohun and LeCun, Yann and Bounou, Oumayma and Izmailov, Pavel and Goldblum, Micah},
    journal = {arXiv preprint arXiv:2512.09929},
    url     = {https://arxiv.org/abs/2512.09929},
    year    = {2025}
}
