[BUG] Struggling to reproduce GSM8K GRPO Training on B200 with CUDA 12.8. #3141

@lulmer

Describe the bug

Hello TorchRL team, I am opening this issue because I cannot run the GSM8K GRPO training example from rl/sota-implementations/grpo. It fails with KeyError: 'key "full" not found in LazyStackedTensorDict' during the data collection phase when training on Blackwell (B200) with CUDA 12.8.

To Reproduce

Steps to reproduce the behavior.

It took me a lot of trial and error to install the environment properly, because there are compatibility issues between the latest CUDA 12.8, transformers, vllm, ray and pytorch. Below is an environment that I think should work, although it may not be ideal.

Create uv environment with Python 3.12

uv venv --python 3.12
source .venv/bin/activate

Install the required dependencies (this assumes torchrl is installed in editable mode from the rl directory):

uv pip install -r requirements_grpo_blackwell.txt
cd rl
uv pip install -e . 

The requirements_grpo_blackwell.txt file contains the following:

absl-py==2.3.1
accelerate==1.10.1
aiohappyeyeballs==2.6.1
aiohttp==3.12.15
aiosignal==1.4.0
airportsdata==20250811
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.10.0
astor==0.8.1
attrs==25.3.0
bitsandbytes==0.47.0
blake3==1.0.5
cachetools==6.2.0
cbor2==5.7.0
certifi==2025.8.3
cffi==1.17.1
charset-normalizer==3.4.3
click==8.2.1
cloudpickle==3.1.1
compressed-tensors==0.9.4
cupy-cuda12x==13.6.0
datasets==4.0.0
depyf==0.18.0
dill==0.3.8
diskcache==5.6.3
distro==1.9.0
dnspython==2.7.0
einops==0.8.1
email-validator==2.3.0
fastapi==0.116.1
fastapi-cli==0.0.8
fastapi-cloud-cli==0.1.5
fastrlock==0.8.3
filelock==3.19.1
frozenlist==1.7.0
fsspec==2024.6.1
gguf==0.17.1
gitdb==4.0.12
gitpython==3.1.45
googleapis-common-protos==1.70.0
grpcio==1.74.0
h11==0.16.0
hf-xet==1.1.8
httpcore==1.0.9
httptools==0.6.4
httpx==0.28.1
huggingface-hub==0.34.4
hydra-core==1.3.2
idna==3.10
importlib-metadata==8.7.0
interegular==0.3.3
jinja2==3.1.6
jiter==0.10.0
jsonschema==4.25.1
jsonschema-specifications==2025.4.1
lark==1.2.2
llguidance==0.7.30
llvmlite==0.44.0
lm-format-enforcer==0.10.12
markdown==3.8.2
markdown-it-py==4.0.0
markupsafe==2.1.5
mdurl==0.1.2
mistral-common==1.8.4
mpmath==1.3.0
msgpack==1.1.1
msgspec==0.19.0
multidict==6.6.4
multiprocess==0.70.16
nest-asyncio==1.6.0
networkx==3.3
ninja==1.13.0
numba==0.61.2
numpy==2.1.2
nvidia-cublas-cu12==12.8.3.14
nvidia-cuda-cupti-cu12==12.8.57
nvidia-cuda-nvrtc-cu12==12.8.61
nvidia-cuda-runtime-cu12==12.8.57
nvidia-cudnn-cu12==9.7.1.26
nvidia-cufft-cu12==11.3.3.41
nvidia-cufile-cu12==1.13.0.11
nvidia-curand-cu12==10.3.9.55
nvidia-cusolver-cu12==11.7.2.55
nvidia-cusparse-cu12==12.5.7.53
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.8.61
nvidia-nvtx-cu12==12.8.55
omegaconf==2.3.0
openai==1.102.0
openai-harmony==0.0.4
opencv-python-headless==4.12.0.88
opentelemetry-api==1.36.0
opentelemetry-exporter-otlp==1.36.0
opentelemetry-exporter-otlp-proto-common==1.36.0
opentelemetry-exporter-otlp-proto-grpc==1.36.0
opentelemetry-exporter-otlp-proto-http==1.36.0
opentelemetry-proto==1.36.0
opentelemetry-sdk==1.36.0
opentelemetry-semantic-conventions==0.57b0
opentelemetry-semantic-conventions-ai==0.4.13
orjson==3.11.3
outlines==0.1.11
outlines-core==0.1.26
packaging==25.0
pandas==2.3.2
partial-json-parser==0.2.1.1.post6
peft==0.17.1
pillow==11.0.0
platformdirs==4.4.0
prometheus-client==0.22.1
prometheus-fastapi-instrumentator==7.1.0
propcache==0.3.2
protobuf==6.32.0
psutil==7.0.0
py-cpuinfo==9.0.0
pyarrow==21.0.0
pybase64==1.4.2
pycountry==24.6.1
pycparser==2.22
pydantic==2.11.7
pydantic-core==2.33.2
pydantic-extra-types==2.10.5
pygments==2.19.2
python-dateutil==2.9.0.post0
python-dotenv==1.1.1
python-json-logger==3.3.0
python-multipart==0.0.20
pytz==2025.2
pyvers==0.1.0
pyyaml==6.0.2
pyzmq==27.0.2
ray==2.46.0
referencing==0.36.2
regex==2025.7.34
requests==2.32.5
rich==14.1.0
rich-toolkit==0.15.0
rignore==0.6.4
rpds-py==0.27.1
safetensors==0.6.2
scipy==1.16.1
sentencepiece==0.2.1
sentry-sdk==2.35.1
setproctitle==1.3.6
setuptools==79.0.1
shellingham==1.5.4
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
soundfile==0.13.1
soxr==0.5.0.post1
starlette==0.47.3
sympy==1.13.3
tensorboard==2.20.0
tensorboard-data-server==0.7.2
tensordict==0.9.1
tiktoken==0.11.0
tokenizers==0.21.4
torch==2.7.0+cu128
torchaudio==2.7.0
torchvision==0.22.0+cu128
tqdm==4.67.1
transformers==4.53.3
triton==3.3.0
typer==0.16.1
typing-extensions==4.12.2
typing-inspection==0.4.1
tzdata==2025.2
urllib3==2.5.0
uvicorn==0.35.0
uvloop==0.21.0
vllm==0.9.0
wandb==0.21.1
watchfiles==1.1.0
websockets==15.0.1
werkzeug==3.1.3
xformers==0.0.30
xgrammar==0.1.19
xxhash==3.5.0
yarl==1.20.1
zipp==3.23.0
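Since the pinned list is long, it can help to confirm that the active environment actually matches it before launching training. The following is a minimal sketch, not part of the original report: the helper names `parse_pins`, `find_mismatches` and `report` are hypothetical, and it only uses the standard-library `importlib.metadata` module to compare installed versions against the `name==version` lines.

```python
from importlib import metadata


def parse_pins(text):
    """Parse 'name==version' lines into a {name: version} dict."""
    pins = {}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()  # drop comments and whitespace
        if "==" not in line:
            continue  # skip blank lines and anything that is not an exact pin
        name, version = line.split("==", 1)
        pins[name.strip().lower()] = version.strip()
    return pins


def find_mismatches(pins, installed_version=metadata.version):
    """Return {name: (pinned, installed_or_None)} for every unsatisfied pin."""
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = installed_version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


def report(path="requirements_grpo_blackwell.txt"):
    """Print one line per package whose installed version differs from the pin."""
    with open(path) as fh:
        problems = find_mismatches(parse_pins(fh.read()))
    for name, (pinned, installed) in sorted(problems.items()):
        print(f"{name}: pinned {pinned}, installed {installed or 'missing'}")
```

Calling `report()` from the directory that holds the requirements file prints nothing when every pin is satisfied. The comparison is an exact string match, so a local version suffix such as `+cu128` on the installed package is reported as a mismatch when the pin omits it.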

Launch the GSM8K GRPO script as described in README.md:

VLLM_USE_V1=0 python sota-implementations/grpo/grpo-sync.py mode=sync train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2

Expected behavior

Training should start and run without raising the KeyError described above.

System info

Describe the characteristics of your environment:

  • Describe how the library was installed (pip, source, ...)
    The library was installed with uv (editable install from source)
  • Python version
    3.12.11
  • Versions of any other relevant libraries
    torch 2.7.0
    ray 2.46.0
    vllm 0.9.0
    transformers 4.53.3
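The version list above can also be gathered programmatically, which makes it easier to compare environments across machines. This is a small sketch under stated assumptions: the function names and the package tuple are illustrative choices (including `torchrl` and `tensordict`, which the report does not list in this section), and only standard-library calls are used.

```python
import platform
from importlib import metadata

# Packages most relevant to this report; adjust the tuple as needed.
PACKAGES = ("torch", "torchrl", "tensordict", "vllm", "ray", "transformers")


def collect_versions(packages=PACKAGES):
    """Map each package name to its installed version or 'not installed'."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


def format_report(versions):
    """Render the interpreter, platform and package versions as plain text."""
    lines = [f"python {platform.python_version()} on {platform.platform()}"]
    lines += [f"{name} {version}" for name, version in versions.items()]
    return "\n".join(lines)


print(format_report(collect_versions()))
```

The output can be pasted directly into the System info section of a bug report.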

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

Metadata

Labels

bug (Something isn't working)
