diff --git a/.github/unittest/linux_libs/scripts_llm/environment.yml b/.github/unittest/llm/scripts_llm/environment.yml similarity index 100% rename from .github/unittest/linux_libs/scripts_llm/environment.yml rename to .github/unittest/llm/scripts_llm/environment.yml diff --git a/.github/unittest/linux_libs/scripts_llm/install.sh b/.github/unittest/llm/scripts_llm/install.sh similarity index 73% rename from .github/unittest/linux_libs/scripts_llm/install.sh rename to .github/unittest/llm/scripts_llm/install.sh index 5f742827573..f5d23d2afbd 100644 --- a/.github/unittest/linux_libs/scripts_llm/install.sh +++ b/.github/unittest/llm/scripts_llm/install.sh @@ -30,15 +30,15 @@ git submodule sync && git submodule update --init --recursive #printf "Installing PyTorch with cu128" #if [[ "$TORCH_VERSION" == "nightly" ]]; then # if [ "${CU_VERSION:-}" == cpu ] ; then -# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U +# pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U # else -# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U +# pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U # fi #elif [[ "$TORCH_VERSION" == "stable" ]]; then # if [ "${CU_VERSION:-}" == cpu ] ; then -# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu +# pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu # else -# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128 +# pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128 # fi #else # printf "Failed to install pytorch" @@ -47,9 +47,10 @@ git submodule sync && git submodule update --init --recursive # install tensordict if [[ "$RELEASE" == 0 ]]; then - pip3 install git+https://github.com/pytorch/tensordict.git + pip install "pybind11[global]" ninja + pip install git+https://github.com/pytorch/tensordict.git else - pip3 install tensordict + pip install tensordict fi # smoke test diff --git a/.github/unittest/linux_libs/scripts_llm/post_process.sh b/.github/unittest/llm/scripts_llm/post_process.sh similarity index 100% rename from .github/unittest/linux_libs/scripts_llm/post_process.sh rename to .github/unittest/llm/scripts_llm/post_process.sh diff --git a/.github/unittest/linux_libs/scripts_llm/run-clang-format.py b/.github/unittest/llm/scripts_llm/run-clang-format.py similarity index 100% rename from .github/unittest/linux_libs/scripts_llm/run-clang-format.py rename to .github/unittest/llm/scripts_llm/run-clang-format.py diff --git a/.github/unittest/linux_libs/scripts_llm/run_test.sh b/.github/unittest/llm/scripts_llm/run_test.sh similarity index 59% rename from .github/unittest/linux_libs/scripts_llm/run_test.sh rename to .github/unittest/llm/scripts_llm/run_test.sh index ac60ae37f1e..bf811b01eb6 100644 --- a/.github/unittest/linux_libs/scripts_llm/run_test.sh +++ b/.github/unittest/llm/scripts_llm/run_test.sh @@ -23,14 +23,4 @@ lib_dir="${env_dir}/lib" conda deactivate && conda activate ./env -python -c "import transformers, datasets" - -pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips - -python examples/rlhf/train_rlhf.py \ - sys.device=cuda:0 sys.ref_device=cuda:0 \ - model.name_or_path=gpt2 train.max_epochs=2 \ - data.batch_size=2 train.ppo.ppo_batch_size=2 \ - train.ppo.ppo_num_epochs=1 reward_model.name_or_path= \ - train.ppo.episode_length=8 train.ppo.num_rollouts_per_epoch=4 \ - data.block_size=110 io.logger=csv +pytest test/llm -vvv --instafail --durations 600 --capture no --error-for-skips diff --git a/.github/unittest/linux_libs/scripts_llm/setup_env.sh b/.github/unittest/llm/scripts_llm/setup_env.sh similarity index 76% rename from .github/unittest/linux_libs/scripts_llm/setup_env.sh rename to .github/unittest/llm/scripts_llm/setup_env.sh index 53dfc0bd50b..0245a7fe790 100644 --- a/.github/unittest/linux_libs/scripts_llm/setup_env.sh +++ b/.github/unittest/llm/scripts_llm/setup_env.sh @@ -6,28 +6,19 @@ # Do not install PyTorch and torchvision here, otherwise they also get cached. set -e -apt-get update && apt-get upgrade -y && apt-get install -y git cmake +export DEBIAN_FRONTEND=noninteractive +export TZ=UTC +apt-get update +apt-get install -yq --no-install-recommends git wget unzip curl patchelf # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' -apt-get install -y wget \ - gcc \ - g++ \ - unzip \ - curl \ - patchelf \ - libosmesa6-dev \ - libgl1-mesa-glx \ - libglfw3 \ - swig3.0 \ - libglew-dev \ - libglvnd0 \ - libgl1 \ - libglx0 \ - libegl1 \ - libgles2 +# The base PyTorch devel image provides compilers, CMake >= 3.22, and most build deps. +# Install only minimal utilities not guaranteed to be present. -# Upgrade specific package -apt-get upgrade -y libstdc++6 +# CMake available in the PyTorch devel image (Ubuntu 22.04) is sufficient. + +# Cleanup APT cache +apt-get clean && rm -rf /var/lib/apt/lists/* this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" root_dir="$(git rev-parse --show-toplevel)" diff --git a/.github/workflows/test-linux-llm.yml b/.github/workflows/test-linux-llm.yml index 4de8b8165d9..5f2c4199515 100644 --- a/.github/workflows/test-linux-llm.yml +++ b/.github/workflows/test-linux-llm.yml @@ -21,17 +21,18 @@ permissions: jobs: unittests: + if: ${{ github.event_name == 'push' || (github.event_name == 'pull_request' && contains(join(github.event.pull_request.labels.*.name, ', '), 'llm/')) }} strategy: matrix: - python_version: ["3.9"] - cuda_arch_version: ["12.8"] + python_version: ["3.12"] + cuda_arch_version: ["12.9"] uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: repository: pytorch/rl - runner: "linux.g5.4xlarge.nvidia.gpu" + runner: "linux.g6.4xlarge.experimental.nvidia.gpu" # gpu-arch-type: cuda # gpu-arch-version: "11.7" - docker-image: "nvidia/cudagl:11.4.0-base" + docker-image: "pytorch/pytorch:2.8.0-cuda12.9-cudnn9-devel" timeout: 120 script: | if [[ "${{ github.ref }}" =~ release/* ]]; then @@ -43,14 +44,14 @@ jobs: fi set -euo pipefail - export PYTHON_VERSION="3.9" - export CU_VERSION="cu117" + export PYTHON_VERSION="3.12" + export CU_VERSION="cu129" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export TD_GET_DEFAULTS_TO_NONE=1 - bash .github/unittest/linux_libs/scripts_llm/setup_env.sh - bash .github/unittest/linux_libs/scripts_llm/install.sh - bash .github/unittest/linux_libs/scripts_llm/run_test.sh - bash .github/unittest/linux_libs/scripts_llm/post_process.sh + bash .github/unittest/llm/scripts_llm/setup_env.sh + bash .github/unittest/llm/scripts_llm/install.sh + bash .github/unittest/llm/scripts_llm/run_test.sh + bash .github/unittest/llm/scripts_llm/post_process.sh diff --git a/test/llm/test_updaters.py b/test/llm/test_updaters.py index 02e2efed163..4e9c115f7ba 100644 --- a/test/llm/test_updaters.py +++ b/test/llm/test_updaters.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +from __future__ import annotations import argparse import gc diff --git a/torchrl/modules/llm/backends/vllm/vllm_async.py b/torchrl/modules/llm/backends/vllm/vllm_async.py index d7435b9b99e..e13d47f1e28 100644 --- a/torchrl/modules/llm/backends/vllm/vllm_async.py +++ b/torchrl/modules/llm/backends/vllm/vllm_async.py @@ -20,12 +20,8 @@ from concurrent.futures import ThreadPoolExecutor, wait from typing import Any, Literal, TYPE_CHECKING -import ray - import torch -from ray.util.placement_group import placement_group, remove_placement_group -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from torchrl._utils import logger as torchrl_logger # Import RLvLLMEngine and shared utilities @@ -51,6 +47,25 @@ _has_vllm = False +def _get_ray(): + """Import Ray on demand to avoid global import side-effects. + + Returns: + ModuleType: The imported Ray module. + + Raises: + ImportError: If Ray is not installed. + """ + try: + import ray # type: ignore + + return ray + except Exception as e: # pragma: no cover - surfaced to callers + raise ImportError( + "ray is not installed. Please install it with `pip install ray`." + ) from e + + class _AsyncvLLMWorker: """Async vLLM worker extension for Ray with weight update capabilities.""" @@ -264,7 +279,7 @@ async def generate( "vllm is not installed. Please install it with `pip install vllm`." ) - from vllm import RequestOutput, SamplingParams, TokensPrompt + from vllm import SamplingParams, TokensPrompt # Track whether input was originally a single prompt single_prompt_input = False @@ -471,11 +486,7 @@ def _gpus_per_replica(engine_args: AsyncEngineArgs) -> int: ) -# Create Ray remote versions -if ray is not None and _has_vllm: - _AsyncLLMEngineActor = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine) -else: - _AsyncLLMEngineActor = None +# Ray actor wrapper is created lazily in __init__ to avoid global Ray import. class AsyncVLLM(RLvLLMEngine): @@ -580,17 +591,18 @@ def __init__( raise ImportError( "vllm is not installed. Please install it with `pip install vllm`." ) - if ray is None: - raise ImportError( - "ray is not installed. Please install it with `pip install ray`." - ) + # Lazily import ray only when constructing the actor class to avoid global import # Enable prefix caching by default for better performance engine_args.enable_prefix_caching = enable_prefix_caching self.engine_args = engine_args self.num_replicas = num_replicas - self.actor_class = actor_class or _AsyncLLMEngineActor + if actor_class is None: + ray = _get_ray() + self.actor_class = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine) + else: + self.actor_class = actor_class self.actors: list = [] self._launched = False self._service_id = uuid.uuid4().hex[ @@ -605,6 +617,11 @@ def _launch(self): torchrl_logger.warning("AsyncVLLMEngineService already launched") return + # Local imports to avoid global Ray dependency + ray = _get_ray() + from ray.util.placement_group import placement_group + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + torchrl_logger.info( f"Launching {self.num_replicas} async vLLM engine actors..." ) @@ -944,6 +961,7 @@ def generate( Returns: RequestOutput | list[RequestOutput]: Generated outputs from vLLM. """ + ray = _get_ray() # Check if this is a batch request if self._is_batch(prompts, prompt_token_ids): # Handle batched input by unbinding and sending individual requests @@ -1068,6 +1086,9 @@ def shutdown(self): f"Shutting down {len(self.actors)} async vLLM engine actors..." ) + ray = _get_ray() + from ray.util.placement_group import remove_placement_group + # Kill all actors for i, actor in enumerate(self.actors): try: @@ -1260,6 +1281,7 @@ def _update_weights_with_nccl_broadcast_simple( ) updated_weights = 0 + ray = _get_ray() with torch.cuda.device(0): # Ensure we're on the correct CUDA device for name, weight in gpu_weights.items(): # Convert dtype to string name (like periodic-mono) @@ -1336,6 +1358,7 @@ def get_num_unfinished_requests( "AsyncVLLM service must be launched before getting request counts" ) + ray = _get_ray() if actor_index is not None: if not (0 <= actor_index < len(self.actors)): raise IndexError( @@ -1366,6 +1389,7 @@ def get_cache_usage(self, actor_index: int | None = None) -> float | list[float] "AsyncVLLM service must be launched before getting cache usage" ) + ray = _get_ray() if actor_index is not None: if not (0 <= actor_index < len(self.actors)): raise IndexError( @@ -1678,6 +1702,7 @@ def _select_by_requests(self) -> int: futures = [ actor.get_num_unfinished_requests.remote() for actor in self.actors ] + ray = _get_ray() request_counts = ray.get(futures) # Find the actor with minimum pending requests @@ -1705,6 +1730,7 @@ def _select_by_cache_usage(self) -> int: else: # Query actors directly futures = [actor.get_cache_usage.remote() for actor in self.actors] + ray = _get_ray() cache_usages = ray.get(futures) # Find the actor with minimum cache usage @@ -1844,7 +1870,8 @@ def _is_actor_overloaded(self, actor_index: int) -> bool: futures = [ actor.get_num_unfinished_requests.remote() for actor in self.actors ] - request_counts = ray.get(futures) + ray = _get_ray() + request_counts = ray.get(futures) if not request_counts: return False @@ -1893,8 +1920,9 @@ def get_stats(self) -> dict[str, Any]: cache_futures = [ actor.get_cache_usage.remote() for actor in self.actors ] - request_counts = ray.get(request_futures) - cache_usages = ray.get(cache_futures) + ray = _get_ray() + request_counts = ray.get(request_futures) + cache_usages = ray.get(cache_futures) for i, (requests, cache_usage) in enumerate( zip(request_counts, cache_usages)