Description
Describe the bug
Hello TorchRL team, I am opening this issue because I cannot run the GSM8K GRPO training example from rl/sota-implementations/grpo. It fails with KeyError: 'key "full" not found in LazyStackedTensorDict' during the data collection phase when training on Blackwell (B200) GPUs with CUDA 12.8.
To Reproduce
Steps to reproduce the behavior.
It took me a lot of trial and error to set up the environment, because there are compatibility issues between CUDA 12.8, transformers, vllm, ray and pytorch. Below is an environment that I think should work (although it may not be ideal).
Create uv environment with Python 3.12
uv venv --python 3.12
source .venv/bin/activate
Install the required dependencies, then install torchrl in editable mode from the rl checkout:
uv pip install -r requirements_grpo_blackwell.txt
cd rl
uv pip install -e .
The requirements_grpo_blackwell.txt file contains the following:
absl-py==2.3.1
accelerate==1.10.1
aiohappyeyeballs==2.6.1
aiohttp==3.12.15
aiosignal==1.4.0
airportsdata==20250811
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.10.0
astor==0.8.1
attrs==25.3.0
bitsandbytes==0.47.0
blake3==1.0.5
cachetools==6.2.0
cbor2==5.7.0
certifi==2025.8.3
cffi==1.17.1
charset-normalizer==3.4.3
click==8.2.1
cloudpickle==3.1.1
compressed-tensors==0.9.4
cupy-cuda12x==13.6.0
datasets==4.0.0
depyf==0.18.0
dill==0.3.8
diskcache==5.6.3
distro==1.9.0
dnspython==2.7.0
einops==0.8.1
email-validator==2.3.0
fastapi==0.116.1
fastapi-cli==0.0.8
fastapi-cloud-cli==0.1.5
fastrlock==0.8.3
filelock==3.19.1
frozenlist==1.7.0
fsspec==2024.6.1
gguf==0.17.1
gitdb==4.0.12
gitpython==3.1.45
googleapis-common-protos==1.70.0
grpcio==1.74.0
h11==0.16.0
hf-xet==1.1.8
httpcore==1.0.9
httptools==0.6.4
httpx==0.28.1
huggingface-hub==0.34.4
hydra-core==1.3.2
idna==3.10
importlib-metadata==8.7.0
interegular==0.3.3
jinja2==3.1.6
jiter==0.10.0
jsonschema==4.25.1
jsonschema-specifications==2025.4.1
lark==1.2.2
llguidance==0.7.30
llvmlite==0.44.0
lm-format-enforcer==0.10.12
markdown==3.8.2
markdown-it-py==4.0.0
markupsafe==2.1.5
mdurl==0.1.2
mistral-common==1.8.4
mpmath==1.3.0
msgpack==1.1.1
msgspec==0.19.0
multidict==6.6.4
multiprocess==0.70.16
nest-asyncio==1.6.0
networkx==3.3
ninja==1.13.0
numba==0.61.2
numpy==2.1.2
nvidia-cublas-cu12==12.8.3.14
nvidia-cuda-cupti-cu12==12.8.57
nvidia-cuda-nvrtc-cu12==12.8.61
nvidia-cuda-runtime-cu12==12.8.57
nvidia-cudnn-cu12==9.7.1.26
nvidia-cufft-cu12==11.3.3.41
nvidia-cufile-cu12==1.13.0.11
nvidia-curand-cu12==10.3.9.55
nvidia-cusolver-cu12==11.7.2.55
nvidia-cusparse-cu12==12.5.7.53
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.8.61
nvidia-nvtx-cu12==12.8.55
omegaconf==2.3.0
openai==1.102.0
openai-harmony==0.0.4
opencv-python-headless==4.12.0.88
opentelemetry-api==1.36.0
opentelemetry-exporter-otlp==1.36.0
opentelemetry-exporter-otlp-proto-common==1.36.0
opentelemetry-exporter-otlp-proto-grpc==1.36.0
opentelemetry-exporter-otlp-proto-http==1.36.0
opentelemetry-proto==1.36.0
opentelemetry-sdk==1.36.0
opentelemetry-semantic-conventions==0.57b0
opentelemetry-semantic-conventions-ai==0.4.13
orjson==3.11.3
outlines==0.1.11
outlines-core==0.1.26
packaging==25.0
pandas==2.3.2
partial-json-parser==0.2.1.1.post6
peft==0.17.1
pillow==11.0.0
platformdirs==4.4.0
prometheus-client==0.22.1
prometheus-fastapi-instrumentator==7.1.0
propcache==0.3.2
protobuf==6.32.0
psutil==7.0.0
py-cpuinfo==9.0.0
pyarrow==21.0.0
pybase64==1.4.2
pycountry==24.6.1
pycparser==2.22
pydantic==2.11.7
pydantic-core==2.33.2
pydantic-extra-types==2.10.5
pygments==2.19.2
python-dateutil==2.9.0.post0
python-dotenv==1.1.1
python-json-logger==3.3.0
python-multipart==0.0.20
pytz==2025.2
pyvers==0.1.0
pyyaml==6.0.2
pyzmq==27.0.2
ray==2.46.0
referencing==0.36.2
regex==2025.7.34
requests==2.32.5
rich==14.1.0
rich-toolkit==0.15.0
rignore==0.6.4
rpds-py==0.27.1
safetensors==0.6.2
scipy==1.16.1
sentencepiece==0.2.1
sentry-sdk==2.35.1
setproctitle==1.3.6
setuptools==79.0.1
shellingham==1.5.4
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
soundfile==0.13.1
soxr==0.5.0.post1
starlette==0.47.3
sympy==1.13.3
tensorboard==2.20.0
tensorboard-data-server==0.7.2
tensordict==0.9.1
tiktoken==0.11.0
tokenizers==0.21.4
torch==2.7.0+cu128
torchaudio==2.7.0
torchvision==0.22.0+cu128
tqdm==4.67.1
transformers==4.53.3
triton==3.3.0
typer==0.16.1
typing-extensions==4.12.2
typing-inspection==0.4.1
tzdata==2025.2
urllib3==2.5.0
uvicorn==0.35.0
uvloop==0.21.0
vllm==0.9.0
wandb==0.21.1
watchfiles==1.1.0
websockets==15.0.1
werkzeug==3.1.3
xformers==0.0.30
xgrammar==0.1.19
xxhash==3.5.0
yarl==1.20.1
zipp==3.23.0
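Since the environment was assembled by trial and error, it may help to confirm that what is actually installed matches the pins above before launching. The following is only a minimal sketch using the standard library; the helper names (`parse_pins`, `find_mismatches`, `report`) are hypothetical and not part of torchrl. It compares version strings exactly, so a local suffix such as `+cu128` has to match as well.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_pins(lines):
    """Parse 'name==version' lines into a dict, skipping blanks and comments."""
    pins = {}
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue
        name, _, pinned = line.partition("==")
        pins[name.strip().lower()] = pinned.strip()
    return pins


def find_mismatches(pins, get_version=version):
    """Return {name: (pinned, installed_or_None)} for every pin that differs."""
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


def report(path="requirements_grpo_blackwell.txt"):
    """Print one line per pinned package whose installed version differs."""
    with open(path) as fh:
        pins = parse_pins(fh)
    for name, (pinned, installed) in sorted(find_mismatches(pins).items()):
        print(f"{name}: pinned {pinned}, installed {installed or 'missing'}")
```

Calling `report()` from the directory containing the requirements file, inside the activated environment, prints nothing when every pin matches.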
Launch the GSM8K GRPO script as described in README.md:
VLLM_USE_V1=0 python sota-implementations/grpo/grpo-sync.py mode=sync train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
Expected behavior
Training should run through the data collection phase without raising the KeyError.
Screenshots
If applicable, add screenshots to help explain your problem.
System info
Describe the characteristic of your environment:
- Describe how the library was installed (pip, source, ...)
The library was installed using uv
- Python version
3.12.11
- Versions of any other relevant libraries
torch 2.7.0
ray 2.46.0
vllm 0.9.0
transformers 4.53.3
Checklist
- [ X ] I have checked that there is no similar issue in the repo (required)
- [ X ] I have read the documentation (required)
- [ X ] I have provided a minimal working example to reproduce the bug (required)