Labels: enhancement (New feature or request)
Motivation
EnvPool has an experimental XLA interface that can be used with Atari games. It allows many Atari environments to be stepped in parallel from inside an XLA-compiled training loop, which can run on an Nvidia accelerator.
From https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool_xla_jax.py#L201C5-L220C46:
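import envpool

# env setup (`args.num_envs` from cleanrl replaced by a local variable)
num_envs = 4
envs = envpool.make(
    "Pong-v5",  # EnvPool registers Atari task IDs without the "ALE/" prefix
    env_type="gym",
    num_envs=num_envs,
    episodic_life=False,
    reward_clip=False,
    seed=42,
)
# attributes expected by cleanrl's gym-style vector-env code
envs.num_envs = num_envs
envs.single_action_space = envs.action_space
envs.single_observation_space = envs.observation_space
envs.is_vector_env = True
# handle: opaque environment state; recv/send/step_env: functions usable inside jax.jit
handle, recv, send, step_env = envs.xla()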
Solution
PyTorch has XLA support via PyTorch/XLA, which provides the `openxla` backend for `torch.compile`, so in theory TorchRL could execute an Atari environment on the accelerator.
From https://docs.pytorch.org/xla/master/torch_compile.html:
import torch
import torch_xla.core.xla_model as xm

def add(a, b):
    a_xla = a.to(xm.xla_device())
    b_xla = b.to(xm.xla_device())
    return a_xla + b_xla

compiled_code = torch.compile(add, backend='openxla')
The EnvPool wrappers in TorchRL (`MultiThreadedEnv` / `MultiThreadedEnvWrapper`) could be modified to call `.xla()` on the environments they create when a CUDA device is passed to the wrapper.
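A minimal sketch of the proposed dispatch, assuming a hypothetical helper `maybe_enable_xla` and a hypothetical `XlaFns` container (neither is part of TorchRL or EnvPool). Only `env.xla()` and its four return values come from the EnvPool snippet above; the device check accepts either a string or a `torch.device`, since a CUDA `torch.device` stringifies to `"cuda"` or `"cuda:N"`.

```python
from typing import Any, Callable, NamedTuple, Optional


class XlaFns(NamedTuple):
    """Hypothetical container for the four values returned by EnvPool's env.xla()."""

    handle: Any
    recv: Callable
    send: Callable
    step: Callable


def maybe_enable_xla(env: Any, device: Any = None) -> Optional[XlaFns]:
    """Hypothetical helper: call env.xla() only when `device` is a CUDA device.

    `device` may be None, a string such as "cpu" or "cuda:0", or a
    torch.device (str(torch.device("cuda:0")) == "cuda:0"), so torch
    itself does not need to be imported here.
    """
    if device is None:
        return None
    device_type = str(device).split(":")[0]
    if device_type != "cuda":
        # CPU path: keep using the regular (non-XLA) EnvPool API.
        return None
    handle, recv, send, step = env.xla()
    return XlaFns(handle, recv, send, step)
```

A wrapper could store the result at construction time and fall back to its existing step path when it is `None`. Whether EnvPool's XLA functions, which are built for JAX, can be invoked from PyTorch/XLA code is an open question this sketch does not address.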
Checklist
- I have checked that there is no similar issue in the repo (required)