Labels: enhancement (New feature or request)
Description
The feature, motivation and pitch
It would be very useful to support multi-GPU mjwarp simulations.
Problem
Running mjwarp simulations across multiple GPUs is currently difficult due to limitations in sharding and device handling.
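For reference, here is a minimal JAX sketch of replicated and data-parallel placement for a batched simulation. It uses only the standard `jax.sharding` APIs. `model_params`, `world_state`, and `step` are hypothetical stand-ins, not mjwarp types, and the sketch assumes the number of worlds divides evenly by the device count.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over every visible device (GPUs in the multi-GPU case).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("worlds",))

# Hypothetical stand-ins (NOT mjwarp types): a small "model" array shared by
# all worlds and a batch of per-world state vectors.
model_params = jnp.ones((4,))
num_worlds = 8 * len(devices)  # assumed divisible by the device count
world_state = jnp.zeros((num_worlds, 4))

# Replicated: every device holds a full copy of the model.
replicated = NamedSharding(mesh, P())
# Data-parallel: the leading (world) axis is split across devices.
data_parallel = NamedSharding(mesh, P("worlds"))

model_params = jax.device_put(model_params, replicated)
world_state = jax.device_put(world_state, data_parallel)


@jax.jit
def step(params, state):
    # Placeholder update; params broadcasts over the world axis, so each
    # device only updates its own slice of worlds.
    return state + 0.01 * params


new_state = step(model_params, world_state)
```

An mjwarp integration would need the model/data creation path and the FFI calls to accept inputs placed like `model_params` and `world_state` above.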
Challenges
- Device-side model/data creation: The current API is not compatible with replicated or data-parallel sharding.
- Sharding awareness in JAX FFI integration: The `mjx/third_party/warp/jax_experimental/ffi.py` file (and other code locations) assumes the first visible device instead of running on the device that holds the input arrays.
- Collision data layout: mjwarp stores collision data in a flattened layout, which makes partitioning and sharding more complex.
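One possible way to make a flattened collision buffer shardable along the world axis is to regroup it into a padded per-world array. The sketch below assumes each contact carries the index of the world it belongs to. `to_per_world`, `contact_dist`, and `contact_worldid` are hypothetical names, not mjwarp's actual fields, and contacts beyond `max_per_world` in a given world are silently dropped.

```python
import jax.numpy as jnp


def to_per_world(contact_dist, contact_worldid, nworld, max_per_world):
    """Scatter a flattened contact buffer into a padded (nworld, max_per_world) array.

    contact_dist:    (ncon,) one value per contact (hypothetical field).
    contact_worldid: (ncon,) integer world index of each contact (assumed).
    Empty slots are filled with +inf; overflow contacts are dropped.
    """
    ncon = contact_worldid.shape[0]
    # Sort by world so each world's contacts form one contiguous run.
    order = jnp.argsort(contact_worldid)
    wid = contact_worldid[order]
    dist = contact_dist[order]
    # Start offset of each world's run in the sorted buffer.
    counts = jnp.bincount(wid, length=nworld)
    starts = jnp.cumsum(counts) - counts
    # Position of each contact within its own world.
    slot = jnp.arange(ncon) - starts[wid]
    out = jnp.full((nworld, max_per_world), jnp.inf, dtype=contact_dist.dtype)
    # mode="drop" discards writes whose slot index is out of range (overflow).
    return out.at[wid, slot].set(dist, mode="drop")
```

Under `jax.jit`, `nworld` and `max_per_world` must be static (e.g. via `static_argnums`) because they determine the output shape. The padded result can then take a data-parallel sharding on its first axis. The tradeoff is wasted memory for worlds with few contacts and a hard per-world capacity.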
Alternatives
No response
Additional context
No response