[Feature Request] Envpool XLA and TorchRL #3164

@AdrianOrenstein

Description

Motivation

EnvPool has an experimental XLA interface that can be used with Atari games. It allows many Atari environments to be stepped in parallel alongside computation on an NVIDIA accelerator.

From https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool_xla_jax.py#L201C5-L220C46:

import envpool

# env setup
envs = envpool.make(
    "ALE/Pong-v5",
    env_type="gym",
    num_envs=4,
    episodic_life=False,
    reward_clip=False,
    seed=42,
)
envs.num_envs = args.num_envs
envs.single_action_space = envs.action_space
envs.single_observation_space = envs.observation_space
envs.is_vector_env = True
handle, recv, send, step_env = envs.xla()
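For context on how the objects returned by envs.xla() are used, EnvPool's XLA docs compose the step function with a jitted rollout loop. A minimal self-contained sketch of that pattern is below; policy and step are stand-ins here (in practice step comes from envs.xla() and policy is a real network):

```python
import jax
import jax.numpy as jnp
from jax import lax

# Stand-ins so the sketch runs on its own; in practice these come from
# `handle, recv, send, step = envs.xla()` and an actual policy network.
def policy(states):
    # Hypothetical policy: always pick action 0 for every environment.
    return jnp.zeros(states.shape[0], dtype=jnp.int32)

def step(handle, action):
    # Hypothetical stand-in for EnvPool's XLA step (gym-style outputs).
    obs = jnp.zeros((action.shape[0], 4))
    reward = jnp.zeros(action.shape[0])
    done = jnp.zeros(action.shape[0], dtype=bool)
    return handle, (obs, reward, done, {})

def actor_step(i, loop_var):
    handle, states = loop_var
    action = policy(states)
    handle, (new_states, reward, done, info) = step(handle, action)
    return handle, new_states

@jax.jit
def run_actor_loop(num_steps, init_var):
    # The whole rollout loop is staged into a single XLA computation,
    # so environment stepping never leaves the accelerator.
    return lax.fori_loop(0, num_steps, actor_step, init_var)
```

Because the environment step is itself an XLA op, the entire actor loop compiles into one program rather than bouncing between host and device each step.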

Solution

PyTorch has XLA support via OpenXLA, so in principle TorchRL could execute Atari environments on the accelerator as well.

From https://docs.pytorch.org/xla/master/torch_compile.html:

import torch
import torch_xla.core.xla_model as xm

def add(a, b):
  a_xla = a.to(xm.xla_device())
  b_xla = b.to(xm.xla_device())
  return a_xla + b_xla

compiled_code = torch.compile(add, backend='openxla')

The EnvPool wrapper in TorchRL could be modified to call .xla() on the created environments when a CUDA device is passed into the wrapper.
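One possible shape for that branch is sketched below. This is illustrative only; the function name and the fallback behavior are assumptions, not TorchRL's actual wrapper internals:

```python
def maybe_enable_xla(env, device: str):
    """Return EnvPool's XLA step functions when targeting an accelerator.

    `env` is an EnvPool environment; `device` is a device string such as
    "cpu" or "cuda:0". Both the function name and the None fallback are
    hypothetical, sketching how the wrapper could branch on the device.
    """
    if device.startswith("cuda"):
        # EnvPool's experimental XLA interface.
        handle, recv, send, step = env.xla()
        return handle, recv, send, step
    # Otherwise keep using the regular synchronous/asynchronous API.
    return None
```

The wrapper could then keep its existing CPU path untouched and only switch to the XLA handle-based stepping when an accelerator device is requested.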

Checklist

  • I have checked that there is no similar issue in the repo (required)

Labels

enhancement (New feature or request)
