An MJX-style JAX FFI for CPU-based MuJoCo #2813

@hartikainen

The feature, motivation and pitch

I am training reinforcement learning agents in JAX and currently use MJX as the simulator. The MJX interface is really nice to work with, as it allows me to reason about a single environment at a time using familiar MuJoCo syntax while leveraging jax.jit and jax.vmap for vectorization and compilation. This results in environments that are both nice to author and, in most cases, performant.

However, while MJX performs well for simpler environments, its speed scales poorly for more complex scenes with many objects and contacts. This limitation manifests not only in performance (e.g., poor collision scaling) but also in usability. JAX/MJX compilation is often slow, which makes debugging, visualization, and testing quite tedious compared to iterating with CPU-based MuJoCo.

I think a promising solution could be to introduce a clean, vectorized interface for CPU MuJoCo that mirrors the MJX API. This could be achieved by using the JAX FFI to efficiently manage and step a batch of mujoco.MjData instances. The ideal interface would expose functions like mjx.step and mjx.forward and data structures like mjx.Data and mjx.Model, but with the simulation backend being the CPU-based MuJoCo.

I think this feature could provide the best of both MJX and MuJoCo worlds:

  • The clean, single-program, multiple-data (SPMD) experience authoring environments (like MJX enables via jax.{jit,vmap}).
  • The feature-completeness, scalability, and user-friendliness of the CPU-based MuJoCo.

Alternatives

EnvPool

EnvPool is quite similar in spirit and achieves pretty impressive throughput on CPU. It provides a JAX FFI interface for the environment, but it requires the environment logic itself to be written in C++. What I'm requesting here is a bit different in that I want to define the environment logic in Python/JAX, leveraging the existing MJX/JAX patterns for clean vectorization via jax.{vmap,jit}. EnvPool also appears to be largely unmaintained.

mujoco.rollout

One could theoretically combine mjx.{Model,Data} with mujoco.rollout to achieve a similar outcome. However, my attempts so far have been a bit clumsy and result in unnecessary overhead from Python loops and data marshalling between JAX and NumPy. An optimized FFI-based implementation could probably eliminate these bottlenecks and make the user experience cleaner.

For example, I’ve experimented with something like this:

import mujoco
from mujoco import mjx, rollout

model: mujoco.MjModel = ...
datas = [mujoco.MjData(model) for _ in range(batch_size)]
modelx = mjx.put_model(model)
datax: mjx.Data = ...  # batched data

# This rollout happens on the CPU via Python looping
states, _ = rollout.rollout(
    model,
    datas,
    state,
    control,
    nstep=num_substeps,
)

# Data must be manually transferred back to the JAX side
state = states[:, -1, :]
datax = mjx_set_state(modelx, datax, state)  # hypothetical function
datax = mjx.forward(modelx, datax)
...

Technically, the above allows me to write the environments in JAX, but it’s still a bit clumsy. A native FFI interface would make this setup a whole lot cleaner.
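For reference, the hypothetical mjx_set_state above mostly amounts to slicing the flat state vector back into named fields. Here is a minimal sketch, assuming the state uses MuJoCo's full-physics layout of [time, qpos, qvel, act] and that the model has no plugin state. The helper name split_full_physics_state and its nq, nv, na parameters are mine, not part of any MuJoCo API.

```python
import jax.numpy as jnp


def split_full_physics_state(state, nq, nv, na=0):
    """Split state vectors of shape (..., 1 + nq + nv + na) into named parts.

    Assumes the layout [time, qpos, qvel, act] with no plugin state.
    """
    state = jnp.asarray(state)
    time = state[..., 0]
    qpos = state[..., 1:1 + nq]
    qvel = state[..., 1 + nq:1 + nq + nv]
    act = state[..., 1 + nq + nv:1 + nq + nv + na]
    return time, qpos, qvel, act
```

The resulting arrays could then be written into the batched data with something like datax.replace(time=time, qpos=qpos, qvel=qvel, act=act) before calling mjx.forward, though I have not checked this against every model configuration.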

MjWarp

MjWarp seems promising in improving both collision scaling and compilation/runtime speeds. However, it is still in beta, its final performance and scalability are unclear, and it does not yet have full feature parity with vanilla MuJoCo. I think a vectorized CPU backend would be a valuable, complementary tool, offering the stability, feature set, and usability of MuJoCo. In an ideal world one could author environments in JAX and then swap the backend between MjWarp and CPU-based MuJoCo with barely any changes to the code.
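As a rough illustration of what "swapping the backend" could look like from the environment author's side, the sketch below passes the physics step in as a function. Everything here is hypothetical: PhysicsState, make_env_step, and the two toy backends are stand-ins with a shared signature, not MjWarp or MuJoCo code.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

DT = 0.01  # hypothetical physics timestep


class PhysicsState(NamedTuple):
    qpos: jax.Array
    qvel: jax.Array


# Hypothetical backend signature: (state, ctrl) -> next state.
PhysicsStep = Callable[[PhysicsState, jax.Array], PhysicsState]


def toy_backend_a(state, ctrl):
    """Stand-in backend using explicit Euler integration."""
    return PhysicsState(state.qpos + DT * state.qvel, state.qvel + DT * ctrl)


def toy_backend_b(state, ctrl):
    """Stand-in backend using semi-implicit Euler integration."""
    qvel = state.qvel + DT * ctrl
    return PhysicsState(state.qpos + DT * qvel, qvel)


def make_env_step(physics_step: PhysicsStep, num_substeps: int):
    """Build a single-environment step around whichever backend is passed in."""
    def env_step(state, action):
        for _ in range(num_substeps):
            state = physics_step(state, action)
        reward = -jnp.sum(state.qpos ** 2)  # toy reward: stay near the origin
        return state, reward
    return env_step
```

With this shape, switching backends is a one-argument change, e.g. jax.jit(jax.vmap(make_env_step(toy_backend_b, num_substeps=2))), while the reward and substep logic stay untouched.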

Additional context

I hope this feature request is not completely out of left field! I think it would be a valuable tool for the community and I'd love to hear what you all think. I’d also be interested in contributing to the implementation if others think it would be a good fit for MuJoCo.

Labels: MJX (Using JAX to run on GPU), enhancement (New feature or request)