Skip to content

Commit 15ff85b

Browse files
committed
Update
[ghstack-poisoned]
2 parents: 7dab9d1 + 7b7fd78 · commit: 15ff85b

File tree

13 files changed

+116
-72
lines changed

13 files changed

+116
-72
lines changed

.github/unittest/linux_libs/scripts_llm/install.sh renamed to .github/unittest/llm/scripts_llm/install.sh

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -30,15 +30,15 @@ git submodule sync && git submodule update --init --recursive
3030
#printf "Installing PyTorch with cu128"
3131
#if [[ "$TORCH_VERSION" == "nightly" ]]; then
3232
# if [ "${CU_VERSION:-}" == cpu ] ; then
33-
# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
33+
# pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
3434
# else
35-
# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
35+
# pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
3636
# fi
3737
#elif [[ "$TORCH_VERSION" == "stable" ]]; then
3838
# if [ "${CU_VERSION:-}" == cpu ] ; then
39-
# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
39+
# pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
4040
# else
41-
# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
41+
# pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
4242
# fi
4343
#else
4444
# printf "Failed to install pytorch"
@@ -47,9 +47,10 @@ git submodule sync && git submodule update --init --recursive
4747

4848
# install tensordict
4949
if [[ "$RELEASE" == 0 ]]; then
50-
pip3 install git+https://github.com/pytorch/tensordict.git
50+
pip install "pybind11[global]" ninja
51+
pip install git+https://github.com/pytorch/tensordict.git
5152
else
52-
pip3 install tensordict
53+
pip install tensordict
5354
fi
5455

5556
# smoke test

.github/unittest/linux_libs/scripts_llm/run_test.sh renamed to .github/unittest/llm/scripts_llm/run_test.sh

Lines changed: 1 addition & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -23,14 +23,4 @@ lib_dir="${env_dir}/lib"
2323

2424
conda deactivate && conda activate ./env
2525

26-
python -c "import transformers, datasets"
27-
28-
pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips
29-
30-
python examples/rlhf/train_rlhf.py \
31-
sys.device=cuda:0 sys.ref_device=cuda:0 \
32-
model.name_or_path=gpt2 train.max_epochs=2 \
33-
data.batch_size=2 train.ppo.ppo_batch_size=2 \
34-
train.ppo.ppo_num_epochs=1 reward_model.name_or_path= \
35-
train.ppo.episode_length=8 train.ppo.num_rollouts_per_epoch=4 \
36-
data.block_size=110 io.logger=csv
26+
pytest test/llm -vvv --instafail --durations 600 --capture no --error-for-skips

.github/unittest/linux_libs/scripts_llm/setup_env.sh renamed to .github/unittest/llm/scripts_llm/setup_env.sh

Lines changed: 10 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -6,28 +6,19 @@
66
# Do not install PyTorch and torchvision here, otherwise they also get cached.
77

88
set -e
9-
apt-get update && apt-get upgrade -y && apt-get install -y git cmake
9+
export DEBIAN_FRONTEND=noninteractive
10+
export TZ=UTC
11+
apt-get update
12+
apt-get install -yq --no-install-recommends git wget unzip curl patchelf
1013
# Avoid error: "fatal: unsafe repository"
1114
git config --global --add safe.directory '*'
12-
apt-get install -y wget \
13-
gcc \
14-
g++ \
15-
unzip \
16-
curl \
17-
patchelf \
18-
libosmesa6-dev \
19-
libgl1-mesa-glx \
20-
libglfw3 \
21-
swig3.0 \
22-
libglew-dev \
23-
libglvnd0 \
24-
libgl1 \
25-
libglx0 \
26-
libegl1 \
27-
libgles2
15+
# The base PyTorch devel image provides compilers, CMake >= 3.22, and most build deps.
16+
# Install only minimal utilities not guaranteed to be present.
2817

29-
# Upgrade specific package
30-
apt-get upgrade -y libstdc++6
18+
# CMake available in the PyTorch devel image (Ubuntu 22.04) is sufficient.
19+
20+
# Cleanup APT cache
21+
apt-get clean && rm -rf /var/lib/apt/lists/*
3122

3223
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
3324
root_dir="$(git rev-parse --show-toplevel)"

.github/workflows/test-linux-llm.yml

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -21,17 +21,18 @@ permissions:
2121

2222
jobs:
2323
unittests:
24+
if: ${{ github.event_name == 'push' || (github.event_name == 'pull_request' && contains(join(github.event.pull_request.labels.*.name, ', '), 'llm/')) }}
2425
strategy:
2526
matrix:
26-
python_version: ["3.9"]
27-
cuda_arch_version: ["12.8"]
27+
python_version: ["3.12"]
28+
cuda_arch_version: ["12.9"]
2829
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
2930
with:
3031
repository: pytorch/rl
31-
runner: "linux.g5.4xlarge.nvidia.gpu"
32+
runner: "linux.g6.4xlarge.experimental.nvidia.gpu"
3233
# gpu-arch-type: cuda
3334
# gpu-arch-version: "11.7"
34-
docker-image: "nvidia/cudagl:11.4.0-base"
35+
docker-image: "pytorch/pytorch:2.8.0-cuda12.9-cudnn9-devel"
3536
timeout: 120
3637
script: |
3738
if [[ "${{ github.ref }}" =~ release/* ]]; then
@@ -43,14 +44,14 @@ jobs:
4344
fi
4445
4546
set -euo pipefail
46-
export PYTHON_VERSION="3.9"
47-
export CU_VERSION="cu117"
47+
export PYTHON_VERSION="3.12"
48+
export CU_VERSION="cu129"
4849
export TAR_OPTIONS="--no-same-owner"
4950
export UPLOAD_CHANNEL="nightly"
5051
export TF_CPP_MIN_LOG_LEVEL=0
5152
export TD_GET_DEFAULTS_TO_NONE=1
5253
53-
bash .github/unittest/linux_libs/scripts_llm/setup_env.sh
54-
bash .github/unittest/linux_libs/scripts_llm/install.sh
55-
bash .github/unittest/linux_libs/scripts_llm/run_test.sh
56-
bash .github/unittest/linux_libs/scripts_llm/post_process.sh
54+
bash .github/unittest/llm/scripts_llm/setup_env.sh
55+
bash .github/unittest/llm/scripts_llm/install.sh
56+
bash .github/unittest/llm/scripts_llm/run_test.sh
57+
bash .github/unittest/llm/scripts_llm/post_process.sh

test/llm/test_updaters.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5-
5+
from __future__ import annotations
66

77
import argparse
88
import gc

test/test_collector.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -162,7 +162,7 @@ def forward(self, observation):
162162
output = self.linear(observation)
163163
if self.multiple_outputs:
164164
return output, output.sum(), output.min(), output.max()
165-
return self.linear(observation)
165+
return output
166166

167167

168168
class UnwrappablePolicy(nn.Module):
@@ -1512,6 +1512,7 @@ def create_env():
15121512
cudagraph_policy=cudagraph,
15131513
weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
15141514
)
1515+
assert "policy" in collector._weight_senders, collector._weight_senders.keys()
15151516
try:
15161517
# collect state_dict
15171518
state_dict = collector.state_dict()

test/test_env.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3836,6 +3836,8 @@ def test_parallel(self, bwad, use_buffers, maybe_fork_ParallelEnv):
38363836
finally:
38373837
env.close(raise_if_closed=False)
38383838
del env
3839+
time.sleep(0.1)
3840+
gc.collect()
38393841

38403842
class AddString(Transform):
38413843
def __init__(self):

0 commit comments

Comments
 (0)