ericaweng/Joint_AgentFormer

This is a PyTorch Lightning implementation of Joint AgentFormer, from the paper Joint Metrics Matter, as well as of the original AgentFormer.

Joint AgentFormer

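State-of-the-art trajectory forecasting baselines like AgentFormer optimize for per-agent minimum displacement error metrics such as ADE (Average Displacement Error). Our method, Joint AgentFormer, is instead optimized for multi-agent minimum displacement error metrics such as JADE (Joint ADE). Per-agent metrics let each agent pick its own best sample out of the K predicted futures, whereas joint metrics pick the single sample that is best for all agents in the scene at once.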
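
To make the per-agent vs. joint distinction concrete, here is a minimal NumPy sketch of both metrics for a single scene. It is not the evaluation code from this repository: the function name `min_ade_and_jade` and the array layout (K samples, N agents, T timesteps, 2D positions) are assumptions for illustration.

```python
from typing import Tuple

import numpy as np


def min_ade_and_jade(pred: np.ndarray, gt: np.ndarray) -> Tuple[float, float]:
    """Compare per-agent minADE with joint ADE (JADE) for one scene.

    pred: (K, N, T, 2) array of K sampled futures for N agents over T steps.
    gt:   (N, T, 2) array of ground-truth futures (assumed layout).
    """
    # Euclidean error at every timestep: shape (K, N, T)
    dist = np.linalg.norm(pred - gt[None], axis=-1)
    # Average over time -> ADE of every sample for every agent: shape (K, N)
    ade = dist.mean(axis=-1)
    # Per-agent: each agent independently keeps its own best sample
    min_ade = ade.min(axis=0).mean()
    # Joint: one sample index must be chosen for the whole scene
    jade = ade.mean(axis=1).min()
    return float(min_ade), float(jade)


# Toy scene: 2 samples, 2 agents, 1 timestep, ground truth at the origin.
gt = np.zeros((2, 1, 2))
pred = np.zeros((2, 2, 1, 2))
pred[0, 1, 0] = [3.0, 4.0]  # sample 0 misses agent 1 by 5
pred[1, 0, 0] = [3.0, 4.0]  # sample 1 misses agent 0 by 5
print(min_ade_and_jade(pred, gt))  # (0.0, 2.5)
```

Each agent is predicted perfectly by a different sample, so minADE is 0.0, but no single sample gets both agents right, so JADE is 2.5. By construction, minADE can never exceed JADE.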

Training

All datasets are already included in the repository under the datasets/ directory.

python pl_train.py --cfg <config>

where <config> has the format <dset>_agentformer_pre for plain AgentFormer models and <dset>_joint_pre for Joint AgentFormer, and <dset> is one of eth, hotel, univ, zara1, zara2, or trajnet_sdd.
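
If you want to launch pre-training for every dataset and model variant, the naming scheme above can be captured in a small helper. This is a hypothetical convenience script, not part of the repository: `pretrain_config`, `DATASETS`, and `MODELS` are illustrative names, and the helper only encodes the naming pattern; it does not check that a matching config file exists.

```python
from itertools import product

# Dataset names and model suffixes as listed in this README.
DATASETS = ["eth", "hotel", "univ", "zara1", "zara2", "trajnet_sdd"]
MODELS = ["agentformer", "joint"]  # plain AgentFormer vs. Joint AgentFormer


def pretrain_config(dset: str, model: str) -> str:
    """Return the first-stage (pre-training) config name, e.g. 'eth_joint_pre'."""
    if dset not in DATASETS:
        raise ValueError(f"unknown dataset {dset!r}; expected one of {DATASETS}")
    if model not in MODELS:
        raise ValueError(f"unknown model {model!r}; expected one of {MODELS}")
    return f"{dset}_{model}_pre"


# One pre-training command per (dataset, model) pair, e.g. for a sweep.
commands = [
    ["python", "pl_train.py", "--cfg", pretrain_config(d, m)]
    for d, m in product(DATASETS, MODELS)
]

for cmd in commands:
    print(" ".join(cmd))
```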

Once pre-training finishes, train the DLow model, which improves the diversity of the trajectory predictions:

python pl_train.py --cfg <config>

where <config> has the format <dset>_agentformer for plain AgentFormer and <dset>_joint for Joint AgentFormer, and <dset> is one of eth, hotel, univ, zara1, zara2, or trajnet_sdd.
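
Because DLow training builds on the pre-trained model, the two stages must run in order. Below is a minimal sketch of a driver for one dataset. It assumes it is run from the repository root and that the DLow config locates the pre-trained checkpoint on its own (the README runs the second stage with no extra flags). `two_stage_commands` and `run_two_stage` are hypothetical names.

```python
import subprocess
from typing import List


def two_stage_commands(dset: str, joint: bool) -> List[List[str]]:
    """Build both training commands for one dataset, in the order they must run."""
    model = "joint" if joint else "agentformer"
    return [
        # Stage 1: pre-train the base model
        ["python", "pl_train.py", "--cfg", f"{dset}_{model}_pre"],
        # Stage 2: train DLow on top of it
        # (assumption: the DLow config finds the stage-1 checkpoint by itself)
        ["python", "pl_train.py", "--cfg", f"{dset}_{model}"],
    ]


def run_two_stage(dset: str, joint: bool = True) -> None:
    for cmd in two_stage_commands(dset, joint):
        # check=True raises CalledProcessError if a stage exits with an error,
        # so DLow training is not started after a failed pre-training run.
        subprocess.run(cmd, check=True)


# Example (launches real training, so it is left commented out):
# run_two_stage("zara1", joint=True)
```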

Testing

We also include our own pre-trained models in the results/ directory via Git LFS, so you will need to install Git LFS to download them. Once they are downloaded, run:

python pl_train.py --cfg <config> --mode test

where <config> follows the same format described in the Training section above.
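
If Git LFS was not installed when you cloned, the files under results/ are small text pointer files rather than model weights, and testing will likely fail when it tries to load them. The sketch below looks for files that still begin with the Git LFS pointer header; if any are found, run `git lfs install` followed by `git lfs pull` in the repository. The helper names are hypothetical, and scanning all of results/ is an assumption about where the checkpoints live.

```python
from pathlib import Path
from typing import List

# Every Git LFS pointer file begins with this line (Git LFS pointer spec v1).
LFS_POINTER_HEADER = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path: Path) -> bool:
    """Return True if `path` is an LFS pointer stub rather than real file content."""
    with open(path, "rb") as f:
        return f.read(len(LFS_POINTER_HEADER)) == LFS_POINTER_HEADER


def undownloaded_files(root: str = "results") -> List[Path]:
    """List files under `root` that are still LFS pointers (sorted)."""
    root_path = Path(root)
    if not root_path.is_dir():
        return []
    return sorted(p for p in root_path.rglob("*") if p.is_file() and is_lfs_pointer(p))


if __name__ == "__main__":
    missing = undownloaded_files()
    if missing:
        print(f"{len(missing)} file(s) are still Git LFS pointers; run `git lfs pull`.")
    else:
        print("No Git LFS pointer stubs found under results/.")
```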

Flag Descriptions

--cfg: name of the config file to run
--mode: one of "train", "test", or "val"
--batch_size: only a batch size of 1 is supported right now, sorry :-(
--no_gpu: set this flag for CPU-only training
--dont_resume: set this flag to train from scratch instead of resuming from an existing checkpoint
--checkpoint_path: checkpoint to resume from, if different from the default (./results-joint/<args.cfg>)
--save_viz: save visualizations to ./viz
--save_num: number of visualizations to save per evaluation step
--logs_root: root directory for logs and model checkpoints (default: ./results-joint); logs for a run are saved to <args.logs_root>/<args.cfg>
--save_traj: save predicted trajectories for offline evaluation
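
For reference, the flags above can be summarized as an argparse parser, together with how the run directory and resume source follow from them. This is a hypothetical reconstruction for reading purposes, not the parser in pl_train.py: the flag types, the defaults for --mode and --save_num, and the exact resume logic are assumptions; only the flag names, the mode choices, the batch-size limit, and the ./results-joint default come from this README.

```python
import argparse
from pathlib import Path
from typing import Optional


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical mirror of the documented flags (not the real pl_train.py parser).
    p = argparse.ArgumentParser(description="sketch of pl_train.py flags")
    p.add_argument("--cfg", required=True, help="config name, e.g. eth_joint")
    p.add_argument("--mode", choices=["train", "test", "val"], default="train")  # default assumed
    p.add_argument("--batch_size", type=int, default=1)  # only 1 is supported
    p.add_argument("--no_gpu", action="store_true")
    p.add_argument("--dont_resume", action="store_true")
    p.add_argument("--checkpoint_path", default=None)
    p.add_argument("--save_viz", action="store_true")
    p.add_argument("--save_num", type=int, default=10)  # default assumed
    p.add_argument("--logs_root", default="./results-joint")
    p.add_argument("--save_traj", action="store_true")
    return p


def run_dir(args: argparse.Namespace) -> Path:
    # Logs and checkpoints for a run go to <logs_root>/<cfg>.
    return Path(args.logs_root) / args.cfg


def resume_source(args: argparse.Namespace) -> Optional[Path]:
    # Assumed logic: --dont_resume disables resuming; otherwise an explicit
    # --checkpoint_path takes precedence over the default run directory.
    if args.dont_resume:
        return None
    if args.checkpoint_path is not None:
        return Path(args.checkpoint_path)
    return run_dir(args)


args = build_parser().parse_args(["--cfg", "eth_joint", "--mode", "test"])
print(run_dir(args))  # results-joint/eth_joint (on POSIX systems)
```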

The code is adapted for PyTorch Lightning (multi-GPU training) from:

AgentFormer: Agent-Aware Transformers for Socio-Temporal Multi-Agent Forecasting
Ye Yuan, Xinshuo Weng, Yanglan Ou, Kris Kitani
ICCV 2021

If you find this code useful, we would appreciate it if you cite:

@misc{weng2023joint,
      title={Joint Metrics Matter: A Better Standard for Trajectory Forecasting}, 
      author={Erica Weng and Hana Hoshino and Deva Ramanan and Kris Kitani},
      year={2023},
      eprint={2305.06292},
      archivePrefix={arXiv},
      primaryClass={cs.RO}
}

@misc{yuan2021agentformer,
      title={AgentFormer: Agent-Aware Transformers for Socio-Temporal Multi-Agent Forecasting}, 
      author={Ye Yuan and Xinshuo Weng and Yanglan Ou and Kris Kitani},
      year={2021},
      eprint={2103.14023},
      archivePrefix={arXiv},
      primaryClass={cs.AI}
}

About

PyTorch Lightning implementation of Joint AgentFormer from Joint Metrics Matter (https://arxiv.org/abs/2305.06292) and AgentFormer (https://arxiv.org/abs/2103.14023)
