
Commit 975676d

[Feat] Drop-in Torch CUDA Profiler (#27841)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
1 parent 77d702a commit 975676d

File tree

5 files changed (+76, -29 lines)

docs/contributing/profiling.md

Lines changed: 19 additions & 27 deletions
@@ -39,15 +39,15 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline
 
 ```bash
 VLLM_TORCH_PROFILER_DIR=./vllm_profile \
-vllm serve meta-llama/Meta-Llama-3-70B
+vllm serve meta-llama/Llama-3.1-8B-Instruct
 ```
 
 vllm bench command:
 
 ```bash
 vllm bench serve \
 --backend vllm \
---model meta-llama/Meta-Llama-3-70B \
+--model meta-llama/Llama-3.1-8B-Instruct \
 --dataset-name sharegpt \
 --dataset-path sharegpt.json \
 --profile \
@@ -70,18 +70,21 @@ apt update
 apt install nsight-systems-cli
 ```
 
-### Example commands and usage
+!!! tip
+    When profiling with `nsys`, it is advisable to set the environment variable `VLLM_WORKER_MULTIPROC_METHOD=spawn`. The default is to use the `fork` method instead of `spawn`. More information on the topic can be found in the [Nsight Systems release notes](https://docs.nvidia.com/nsight-systems/ReleaseNotes/index.html#general-issues).
 
-When profiling with `nsys`, it is advisable to set the environment variable `VLLM_WORKER_MULTIPROC_METHOD=spawn`. The default is to use the `fork` method instead of `spawn`. More information on the topic can be found in the [Nsight Systems release notes](https://docs.nvidia.com/nsight-systems/ReleaseNotes/index.html#general-issues).
+The Nsight Systems profiler can be launched with `nsys profile ...`, with a few recommended flags for vLLM: `--trace-fork-before-exec=true --cuda-graph-trace=node`.
+
+### Example commands and usage
 
 #### Offline Inference
 
-For basic usage, you can just append `nsys profile -o report.nsys-rep --trace-fork-before-exec=true --cuda-graph-trace=node` before any existing script you would run for offline inference.
+For basic usage, you can just append the profiling command before any existing script you would run for offline inference.
 
 The following is an example using the `vllm bench latency` script:
 
 ```bash
-nsys profile -o report.nsys-rep \
+nsys profile \
 --trace-fork-before-exec=true \
 --cuda-graph-trace=node \
 vllm bench latency \
@@ -95,40 +98,29 @@ vllm bench latency \
 
 #### OpenAI Server
 
-To profile the server, you will want to prepend your `vllm serve` command with `nsys profile` just like for offline inference, however you must specify `--delay XX --duration YY` parameters according to the needs of your benchmark. After the duration time has been used up, the server will be killed.
+To profile the server, you will want to prepend your `vllm serve` command with `nsys profile` just like for offline inference, but you will need to specify a few other arguments to enable dynamic capture similarly to the Torch Profiler:
 
 ```bash
 # server
-nsys profile -o report.nsys-rep \
+VLLM_TORCH_CUDA_PROFILE=1 \
+nsys profile \
 --trace-fork-before-exec=true \
 --cuda-graph-trace=node \
---delay 30 \
---duration 60 \
+--capture-range=cudaProfilerApi \
+--capture-range-end repeat \
 vllm serve meta-llama/Llama-3.1-8B-Instruct
 
 # client
 vllm bench serve \
 --backend vllm \
 --model meta-llama/Llama-3.1-8B-Instruct \
---num-prompts 1 \
---dataset-name random \
---random-input 1024 \
---random-output 512
-```
-
-In practice, you should set the `--duration` argument to a large value. Whenever you want the server to stop profiling, run:
-
-```bash
-nsys sessions list
-```
-
-to get the session id in the form of `profile-XXXXX`, then run:
-
-```bash
-nsys stop --session=profile-XXXXX
+--dataset-name sharegpt \
+--dataset-path sharegpt.json \
+--profile \
+--num-prompts 2
 ```
 
-to manually kill the profiler and generate your `nsys-rep` report.
+With `--profile`, vLLM will capture a profile for each run of `vllm bench serve`. Once the server is killed, the profiles will all be saved.
 
 #### Analysis
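
The server workflow above hinges on `--capture-range=cudaProfilerApi --capture-range-end repeat`: nsys stays idle until the application calls `cudaProfilerStart()`, records until the matching `cudaProfilerStop()`, and can open a new range on the next start. A minimal standalone sketch of that mechanism (not part of this commit; the script name and matmul workload are illustrative) using PyTorch's thin wrapper over the CUDA profiler API:

```python
# Run under:
#   nsys profile --capture-range=cudaProfilerApi --capture-range-end repeat python demo.py
# nsys records only between torch.cuda.profiler.start() and .stop(),
# which map to cudaProfilerStart()/cudaProfilerStop().
import torch
import torch.cuda.profiler as cuda_profiler

x = torch.randn(4096, 4096, device="cuda")

cuda_profiler.start()        # capture range opens here
for _ in range(10):
    x = x @ x                # region of interest
torch.cuda.synchronize()
cuda_profiler.stop()         # capture range closes; with "repeat", a later start() opens another
```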

vllm/entrypoints/openai/api_server.py

Lines changed: 7 additions & 1 deletion
@@ -1280,10 +1280,16 @@ async def invocations(raw_request: Request):
 
 
 if envs.VLLM_TORCH_PROFILER_DIR:
-    logger.warning(
+    logger.warning_once(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!"
     )
+elif envs.VLLM_TORCH_CUDA_PROFILE:
+    logger.warning_once(
+        "CUDA Profiler is enabled in the API server. This should ONLY be "
+        "used for local development!"
+    )
+if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
 
     @router.post("/start_profile")
     async def start_profile(raw_request: Request):
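
With the CUDA profiler enabled, the `/start_profile` route shown above (and a matching `/stop_profile` route, assumed here to exist alongside it) lets a client toggle capture without restarting the server. A hedged sketch:

```python
# Hedged sketch: toggle worker-side profiling over HTTP.
# Assumes a vLLM server on localhost:8000 launched with VLLM_TORCH_CUDA_PROFILE=1
# (and, for nsys, --capture-range=cudaProfilerApi); /stop_profile is assumed to
# mirror the /start_profile route shown in the diff above.
import requests

BASE = "http://localhost:8000"

requests.post(f"{BASE}/start_profile").raise_for_status()  # workers call cudaProfilerStart
# ... issue the inference requests you want captured ...
requests.post(f"{BASE}/stop_profile").raise_for_status()   # workers call cudaProfilerStop
```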

vllm/envs.py

Lines changed: 6 additions & 0 deletions
@@ -87,6 +87,7 @@
     VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5  # seconds
     VLLM_PLUGINS: list[str] | None = None
     VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
+    VLLM_TORCH_CUDA_PROFILE: bool = False
     VLLM_TORCH_PROFILER_DIR: str | None = None
     VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
     VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
@@ -815,6 +816,11 @@ def get_vllm_port() -> int | None:
     "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
         "VLLM_LORA_RESOLVER_CACHE_DIR", None
     ),
+    # Enables torch CUDA profiling if set.
+    # On NVIDIA GPUs, this will start/stop cudaProfilerApi when triggered.
+    "VLLM_TORCH_CUDA_PROFILE": lambda: bool(
+        os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0"
+    ),
     # Enables torch profiler if set.
     # Both AsyncLLM's CPU traces as well as workers'
     # traces (CPU & GPU) will be saved under this directory.
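
Note the parsing rule implied by the lambda: the variable is treated as a flag, so leaving it unset or setting it to `"0"` disables the profiler, while any other value enables it. An illustrative restatement (the helper name is hypothetical):

```python
# Illustrative restatement of the VLLM_TORCH_CUDA_PROFILE lambda above.
import os

def cuda_profile_enabled() -> bool:
    # Unset or "0" -> False; any other value ("1", "true", even "no") -> True.
    return os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0"
```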

vllm/profiler/gpu_profiler.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class CudaProfilerWrapper:
+    def __init__(self) -> None:
+        self._profiler_running = False
+        # Note: lazy import to avoid dependency issues if CUDA is not available.
+        import torch.cuda.profiler as cuda_profiler
+
+        self._cuda_profiler = cuda_profiler
+
+    def start(self) -> None:
+        try:
+            self._cuda_profiler.start()
+            self._profiler_running = True
+            logger.info_once("Started CUDA profiler")
+        except Exception as e:
+            logger.warning_once("Failed to start CUDA profiler: %s", e)
+
+    def stop(self) -> None:
+        if self._profiler_running:
+            try:
+                self._cuda_profiler.stop()
+                logger.info_once("Stopped CUDA profiler")
+            except Exception as e:
+                logger.warning_once("Failed to stop CUDA profiler: %s", e)
+            finally:
+                self._profiler_running = False
+
+    def shutdown(self) -> None:
+        """Ensure profiler is stopped when shutting down."""
+        self.stop()
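
For context, a hedged sketch of driving `CudaProfilerWrapper` directly, outside of the Worker plumbing below; the matmul workload is illustrative and a CUDA-capable GPU is assumed:

```python
# Hedged usage sketch for CudaProfilerWrapper; the workload is illustrative.
import torch

from vllm.profiler.gpu_profiler import CudaProfilerWrapper

profiler = CudaProfilerWrapper()
x = torch.randn(2048, 2048, device="cuda")

profiler.start()              # cudaProfilerStart; an nsys capture-range opens
for _ in range(5):
    x = x @ x
torch.cuda.synchronize()
profiler.stop()               # cudaProfilerStop; the capture-range closes
profiler.shutdown()           # safe to call again; stop() is guarded by _profiler_running
```

Note the design choice visible in the class itself: `start()` and `stop()` swallow exceptions and only log, so a missing or misbehaving CUDA profiler degrades to a no-op instead of crashing the worker.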

vllm/v1/worker/gpu_worker.py

Lines changed: 7 additions & 1 deletion
@@ -35,6 +35,7 @@
 from vllm.model_executor.models.interfaces import is_mixture_of_experts
 from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
 from vllm.platforms import current_platform
+from vllm.profiler.gpu_profiler import CudaProfilerWrapper
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.mem_constants import GiB_bytes
@@ -116,6 +117,8 @@ def __init__(
                     torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
                 ),
             )
+        elif envs.VLLM_TORCH_CUDA_PROFILE:
+            self.profiler = CudaProfilerWrapper()
         else:
             self.profiler = None
 
@@ -593,7 +596,10 @@ def profile(self, is_start: bool = True):
         else:
             self.profiler.stop()
             # only print profiler results on rank 0
-            if self.local_rank == 0:
+            if (
+                isinstance(self.profiler, torch.profiler.profile)
+                and self.local_rank == 0
+            ):
                 print(
                     self.profiler.key_averages().table(sort_by="self_cuda_time_total")
                 )
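
The widened guard in the last hunk exists because only the torch profiler carries an in-process trace that can be summarized; `CudaProfilerWrapper` just toggles `cudaProfilerStart`/`cudaProfilerStop` and leaves reporting to nsys. A small sketch of the distinction (the helper name is hypothetical):

```python
# Hypothetical helper illustrating why the isinstance check is needed:
# key_averages() only exists on torch.profiler.profile, not on CudaProfilerWrapper.
import torch

def maybe_print_summary(profiler, local_rank: int) -> None:
    if isinstance(profiler, torch.profiler.profile) and local_rank == 0:
        print(profiler.key_averages().table(sort_by="self_cuda_time_total"))
    # For CudaProfilerWrapper there is nothing to print; inspect the nsys report instead.
```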
