This repository contains the code for the paper "Model-based Diffusion for Trajectory Optimization".
Model-based diffusion (MBD) is a novel diffusion-based trajectory optimization framework that employs a dynamics model to approximate the score function. MBD outperforms existing methods (including RL) in terms of sample efficiency and generalization.
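The phrase "employs a dynamics model to approximate the score function" can be unpacked with Tweedie's formula. The derivation below is a sketch under our own assumptions and notation, not a restatement of the paper. It assumes a variance-preserving forward process with noise levels $\bar\alpha_i$ and a target density $p_0(Y) \propto \exp(J(Y)/\lambda)$, where $J$ is the reward of trajectory $Y$ computed by rolling out the dynamics model and $\lambda$ is a temperature. Under these assumptions the score becomes a reward-weighted average of sampled trajectories, so no learned network is needed.

```latex
\begin{align*}
% Forward (noising) process, i = 1..N, with \bar\alpha_0 = 1
Y_i &= \sqrt{\bar\alpha_i}\, Y_0 + \sqrt{1-\bar\alpha_i}\, \varepsilon,
  \qquad \varepsilon \sim \mathcal{N}(0, I) \\
% Tweedie's formula: the score depends only on the posterior mean of Y_0
\nabla_{Y_i} \log p_i(Y_i)
  &= \frac{\sqrt{\bar\alpha_i}\, \mathbb{E}[Y_0 \mid Y_i] - Y_i}{1-\bar\alpha_i} \\
% Posterior over clean trajectories: target density times a Gaussian in Y_0
p(Y_0 \mid Y_i) &\propto p_0(Y_0)\,
  \mathcal{N}\!\Big(Y_0;\ \tfrac{Y_i}{\sqrt{\bar\alpha_i}},\ \big(\tfrac{1}{\bar\alpha_i}-1\big) I\Big) \\
% Monte Carlo estimate: sample from the Gaussian, weight by the reward
\mathbb{E}[Y_0 \mid Y_i] \approx \bar{Y}_0
  &= \frac{\sum_{k=1}^{K} w_k\, Y_0^{(k)}}{\sum_{k=1}^{K} w_k},
  \quad Y_0^{(k)} \sim \mathcal{N}\!\Big(\tfrac{Y_i}{\sqrt{\bar\alpha_i}},\ \big(\tfrac{1}{\bar\alpha_i}-1\big) I\Big),
  \quad w_k = \exp\!\big(J(Y_0^{(k)})/\lambda\big)
\end{align*}
```

Substituting $\bar{Y}_0$ into the second line yields the score estimate. In this reading, the dynamics model enters only through evaluating $J$ on each sampled trajectory.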
To install the required packages, run the following commands:
git clone --depth 1 git@github.com:LeCAR-Lab/model-based-diffusion.git
cd model-based-diffusion
pip install -e .
To optimize a trajectory with model-based diffusion, run the following commands:
cd mbd/planners
python mbd_planner.py --env_name $ENV_NAME
Here, $ENV_NAME is the name of the environment, one of: hopper, halfcheetah, walker2d, ant, humanoidrun, humanoidstandup, humanoidtrack, car2d, pushT.
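A small driver script can build the planner commands when you want to try several of the environments listed above in one go. The helper below is hypothetical: `build_planner_cmd` and `sweep` are our names, not the repository's. It relies only on the `mbd_planner.py --env_name` interface shown above and should be run from `mbd/planners`.

```python
import subprocess
import sys

# Environment names accepted by mbd_planner.py, as listed in this README.
SUPPORTED_ENVS = (
    "hopper", "halfcheetah", "walker2d", "ant", "humanoidrun",
    "humanoidstandup", "humanoidtrack", "car2d", "pushT",
)


def build_planner_cmd(env_name):
    """Return the argv list for one planner run (hypothetical helper)."""
    if env_name not in SUPPORTED_ENVS:
        raise ValueError(f"unknown environment: {env_name!r}")
    # sys.executable reuses the current interpreter (the one pip installed into).
    return [sys.executable, "mbd_planner.py", "--env_name", env_name]


def sweep(env_names, dry_run=True):
    """Build, and optionally run, one planner command per environment."""
    cmds = [build_planner_cmd(name) for name in env_names]
    if not dry_run:
        for cmd in cmds:
            # check=True stops the sweep at the first failing run.
            subprocess.run(cmd, check=True)
    return cmds


if __name__ == "__main__":
    for cmd in sweep(["car2d", "pushT"]):
        print(" ".join(cmd[1:]))
```

With the default `dry_run=True`, the helper only returns the commands so you can inspect them first. Calling `sweep(["car2d", "pushT"], dry_run=False)` launches the two runs one after another.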
To run model-based diffusion with demonstrations, run the following commands:
cd mbd/planners
python mbd_planner.py --env_name $ENV_NAME --enable_demos
Currently, only humanoidtrack and car2d support demonstrations.
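Only humanoidtrack and car2d accept `--enable_demos`, so a launcher that forwards this flag can reject other environments up front. Here is a hypothetical sketch. The name `planner_args` is illustrative; only the flags come from this README.

```python
# Environments that, per this README, currently support demonstrations.
DEMO_ENVS = frozenset({"humanoidtrack", "car2d"})


def planner_args(env_name, enable_demos=False):
    """Build mbd_planner.py arguments, rejecting unsupported demo runs."""
    args = ["--env_name", env_name]
    if enable_demos:
        if env_name not in DEMO_ENVS:
            raise ValueError(f"{env_name!r} does not support --enable_demos")
        args.append("--enable_demos")
    return args
```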
To run model-based diffusion over multiple seeds, run the following commands:
cd mbd/scripts
python run_mbd.py --env_name $ENV_NAME
To visualize the diffusion process, run the following commands:
cd mbd/scripts
python vis_diffusion.py --env_name $ENV_NAME
Make sure you have run the planner first to generate the data.
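Model-based diffusion can also be applied to black-box optimization, where the objective can only be evaluated, not differentiated. The pure-Python sketch below makes the idea concrete on a toy objective. It is not the repository's `mbd_opt.py`. The linear noise schedule, the temperature `lam`, the importance weights proportional to exp(f/lam), the deterministic DDIM-style update, and all function names are our own illustrative assumptions.

```python
import math
import random


def mbd_blackbox(f, dim, n_steps=100, n_samples=256, lam=1e-3,
                 alpha_bar_min=0.01, seed=0):
    """Maximize a black-box f(list[float]) -> float with a sampling-based
    reverse diffusion (illustrative sketch, not the repository's code)."""
    rng = random.Random(seed)
    # Assumed schedule: alpha_bar[0] = 1 (clean) down to alpha_bar_min (noisy).
    alpha_bar = [1.0 - (1.0 - alpha_bar_min) * i / n_steps
                 for i in range(n_steps + 1)]
    # Start from standard Gaussian noise at the noisiest level.
    y = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    for i in range(n_steps, 0, -1):
        a = alpha_bar[i]
        mean = [v / math.sqrt(a) for v in y]
        std = math.sqrt(1.0 / a - 1.0)
        # Candidate clean samples Y0^(k) ~ N(Y_i / sqrt(a), (1/a - 1) I).
        samples = [[m + std * rng.gauss(0.0, 1.0) for m in mean]
                   for _ in range(n_samples)]
        scores = [f(s) for s in samples]
        # Weights w_k proportional to exp(f / lam), shifted by the max for stability.
        best = max(scores)
        weights = [math.exp((s - best) / lam) for s in scores]
        total = sum(weights)
        # The weighted mean approximates E[Y0 | Y_i].
        y0_hat = [sum(w * s[d] for w, s in zip(weights, samples)) / total
                  for d in range(dim)]
        # Deterministic DDIM-style step from level i to level i - 1.
        eps_hat = [(yv - math.sqrt(a) * x0) / math.sqrt(1.0 - a)
                   for yv, x0 in zip(y, y0_hat)]
        a_prev = alpha_bar[i - 1]
        y = [math.sqrt(a_prev) * x0 + math.sqrt(1.0 - a_prev) * e
             for x0, e in zip(y0_hat, eps_hat)]
    # At level 0, alpha_bar = 1, so y equals the last clean estimate.
    return y


def toy_objective(x):
    """Hypothetical test objective with its maximum at (1.5, ..., 1.5)."""
    return -sum((v - 1.5) ** 2 for v in x)


if __name__ == "__main__":
    print(mbd_blackbox(toy_objective, dim=2))
```

The objective is only ever queried through `f(s)`, which is what makes this a black-box method. A smaller `lam` concentrates the weights on the best samples, trading robustness for greediness.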
To run model-based diffusion for black-box optimization, run the following commands:
cd mbd/blackbox
python mbd_opt.py
To run the RL baselines, run the following commands:
cd mbd/rl
python train_brax.py --env_name $ENV_NAME
To run the other zeroth-order trajectory optimization baselines, run the following commands:
cd mbd/planners
python path_integral.py --env_name $ENV_NAME --mode $MODE
Here, $MODE selects the planner, one of: mppi, cem, cma-es.
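The cross-entropy method (CEM), one of the `--mode` options, takes a different approach from the score estimate used by MBD: at each iteration it refits a Gaussian to the best-scoring samples. The following is a generic, minimal CEM sketch on a toy objective, assumed for illustration. It is not the repository's `path_integral.py`, and its hyperparameters are arbitrary.

```python
import random
import statistics


def cem_maximize(f, dim, iters=50, n_samples=64, n_elite=8,
                 init_std=2.0, min_std=1e-3, seed=0):
    """Generic cross-entropy method for maximizing f (illustrative sketch)."""
    rng = random.Random(seed)
    mean = [0.0] * dim
    std = [init_std] * dim
    for _ in range(iters):
        # Sample candidates from a diagonal Gaussian.
        samples = [[rng.gauss(m, s) for m, s in zip(mean, std)]
                   for _ in range(n_samples)]
        # Keep the n_elite highest-scoring candidates.
        elites = sorted(samples, key=f, reverse=True)[:n_elite]
        # Refit the Gaussian to the elites; floor the std to avoid collapse.
        mean = [statistics.fmean([e[d] for e in elites]) for d in range(dim)]
        std = [max(statistics.pstdev([e[d] for e in elites]), min_std)
               for d in range(dim)]
    return mean
```

The `min_std` floor is a common safeguard. Without it, the fitted distribution can shrink to a point before reaching a good solution.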
- This codebase's environments and RL implementation are built on top of Brax.
@misc{pan2024modelbaseddiffusiontrajectoryoptimization,
title={Model-Based Diffusion for Trajectory Optimization},
author={Chaoyi Pan and Zeji Yi and Guanya Shi and Guannan Qu},
year={2024},
eprint={2407.01573},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2407.01573},
}