Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
517b672
fixes and refactors spec-decode cudagraph
fhl2000 Aug 25, 2025
40e1ccb
remove build_for_cudagraph_capture
fhl2000 Aug 26, 2025
3550717
support capturing mutiple uniform_query_len
fhl2000 Aug 26, 2025
a142f14
fix typo
fhl2000 Aug 26, 2025
f7d73f8
fix typo
fhl2000 Aug 26, 2025
02390fc
fix broken examples/offline_inference/spec_decode.py
fhl2000 Aug 26, 2025
198fb66
Merge remote-tracking branch 'origin/main' into fix_cudagraph_drafter
fhl2000 Sep 4, 2025
14c6918
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Sep 7, 2025
ec02778
fix pre-commit
fhl2000 Sep 7, 2025
6b90770
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Sep 9, 2025
286677f
revert spec_decode.py
fhl2000 Sep 10, 2025
874639c
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Sep 14, 2025
0eda111
address comments
fhl2000 Sep 15, 2025
9c50e6e
revert build_for_cudagraph_capturing
fhl2000 Sep 15, 2025
e4a1a78
remove unnecessary assertion
fhl2000 Sep 15, 2025
ce32326
solving conflicts/Merge remote-tracking branch 'origin/main' into fix…
fhl2000 Sep 16, 2025
ad5ba70
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Sep 20, 2025
691c21e
fixes for ubatching
fhl2000 Sep 21, 2025
43b2753
fix CI
fhl2000 Sep 23, 2025
fde10ba
Merge remote-tracking branch 'origin/main' into fix_cudagraph_drafter
fhl2000 Sep 28, 2025
0a3fe05
fix
fhl2000 Sep 28, 2025
804598b
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Sep 29, 2025
40bd81b
Merge remote-tracking branch 'origin/main' into fix_cudagraph_drafter
fhl2000 Oct 6, 2025
a51344e
Merge remote-tracking branch 'origin/main' into fix_cudagraph_drafter
fhl2000 Oct 14, 2025
d170341
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Oct 14, 2025
0ee4aef
WIP:address dp padding issue
fhl2000 Oct 14, 2025
a4872bc
clean up
fhl2000 Oct 15, 2025
872015e
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Oct 15, 2025
d1499c2
Merge remote-tracking branch 'origin/main' into fix_cudagraph_drafter
fhl2000 Nov 1, 2025
c18486a
refactor eagle dummy run
fhl2000 Nov 1, 2025
9b99056
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Nov 4, 2025
299ce7d
fix drafter when enforce_eager
fhl2000 Nov 4, 2025
b5c315a
fix pre-commit
fhl2000 Nov 4, 2025
25d3f3b
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Nov 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/offline_inference/spec_decode.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
Expand Down Expand Up @@ -69,6 +71,7 @@ def parse_args():
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--compilation-config", type=str, default="")
return parser.parse_args()


Expand Down Expand Up @@ -133,6 +136,9 @@ def main():
max_model_len=16384,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
compilation_config=(
json.loads(args.compilation_config) if args.compilation_config else None
),
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
Expand Down
82 changes: 67 additions & 15 deletions vllm/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3362,12 +3362,26 @@ def compute_hash(self) -> str:
usedforsecurity=False).hexdigest()[:10]
return hash_str

def pad_for_cudagraph(self, batch_size: int) -> int:
# if batch_size > self.compilation_config.max_capture_size,
def pad_for_cudagraph(self,
batch_size: int,
uniform_aligned: bool = False) -> int:
""" Get the padded graph size for the batch size.
uniform_aligned: if True, means the padding batch size would be
divisible by the uniform_decode_len for the main model.
For drafter, caller should make sure uniform_aligned is False because
drafter's uniform_decode_len is 1.
"""

# if batch_size > self.compilation_config.max_capture_size when
# uniform_aligned is False, or batch_size > self.compilation_config.
# max_uniform_capture_size when uniform_aligned is True,
# it should raise an IndexError.
# the caller should make sure the batch_size is within the range,
# i.e., batch_size <= self.compilation_config.max_capture_size
return self.compilation_config.bs_to_padded_graph_size[batch_size]
# the caller should make sure the batch_size is within the range
if not uniform_aligned:
return self.compilation_config.bs_to_padded_graph_size[batch_size]
else:
return self.compilation_config.\
bs_to_padded_graph_size_uniform[batch_size]

@staticmethod
def _get_quantization_config(
Expand Down Expand Up @@ -3635,14 +3649,24 @@ def __post_init__(self):
# local attention.
self.scheduler_config.disable_hybrid_kv_cache_manager = True

def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
def update_sizes_for_sequence_parallelism(
self,
possible_sizes: list,
uniform_possible_sizes: Optional[list] = None
) -> tuple[list, Optional[list]]:
# remove the sizes that not multiple of tp_size when
# enable sequence parallelism
removed_sizes = [
size for size in possible_sizes
if size % self.parallel_config.tensor_parallel_size != 0
]
removed_uniform_sizes = []
if uniform_possible_sizes is not None:
removed_uniform_sizes = [
size for size in uniform_possible_sizes
if size % self.parallel_config.tensor_parallel_size != 0
]
removed_sizes = list(set(removed_sizes + removed_uniform_sizes))
if removed_sizes:
logger.warning(
"Batch sizes %s are removed because they are not "
Expand All @@ -3653,7 +3677,10 @@ def update_sizes_for_sequence_parallelism(self,
return [
size for size in possible_sizes
if size % self.parallel_config.tensor_parallel_size == 0
]
], [
size for size in uniform_possible_sizes
if size % self.parallel_config.tensor_parallel_size == 0
] if uniform_possible_sizes else None

def _set_cudagraph_sizes(self):
"""
Expand Down Expand Up @@ -3684,17 +3711,20 @@ def _set_cudagraph_sizes(self):
"""

# calculate the default `batch_size_capture_list`
batch_size_capture_list = []
uniform_batch_size_capture_list = []
uniform_decode_len = 1 if not self.speculative_config else \
1 + self.speculative_config.num_speculative_tokens
if not envs.VLLM_USE_V1:
batch_size_capture_list = []
if self.scheduler_config is not None and \
self.model_config is not None and \
not self.model_config.enforce_eager:

possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
if self.parallel_config.tensor_parallel_size > 1 and \
self.compilation_config.pass_config.enable_sequence_parallelism:
possible_sizes = self.update_sizes_for_sequence_parallelism(
possible_sizes)
possible_sizes, _ = \
self.update_sizes_for_sequence_parallelism(possible_sizes)

# find the minimum size that is larger than max_num_seqs,
# which then becomes the max_batchsize_to_capture
Expand All @@ -3714,7 +3744,6 @@ def _set_cudagraph_sizes(self):
if size <= max_batchsize_to_capture
]
else:
batch_size_capture_list = []
if self.model_config is not None and \
not self.model_config.enforce_eager:
cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
Expand All @@ -3726,18 +3755,41 @@ def _set_cudagraph_sizes(self):
batch_size_capture_list = sorted(cuda_graph_sizes)
else:
raise TypeError(f"Invalid value for {cuda_graph_sizes=}.")

# we maintain a separate list of uniform-decode capture sizes,
# since for spec-decode, we may need capture sizes being
# divisible by uniform_decode_len(>1).

# Derive uniform-decode capture sizes via projection: for each
# non-uniform capture size i, take the max multiple of
# uniform_decode_len that is not greater than i.
projected_sizes: set[int] = set()
for size in batch_size_capture_list:
proj = (size // uniform_decode_len) * uniform_decode_len
if proj >= uniform_decode_len:
projected_sizes.add(proj)
uniform_batch_size_capture_list = sorted(projected_sizes)
if self.parallel_config.tensor_parallel_size > 1 and \
self.compilation_config.pass_config.enable_sequence_parallelism:
batch_size_capture_list = \
self.update_sizes_for_sequence_parallelism(batch_size_capture_list)
batch_size_capture_list, uniform_batch_size_capture_list = \
self.update_sizes_for_sequence_parallelism(
batch_size_capture_list,
uniform_batch_size_capture_list)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
batch_size_capture_list = [
size for size in batch_size_capture_list
if size <= max_num_tokens
]
max_num_decode_tokens = self.scheduler_config.max_num_seqs * \
uniform_decode_len
uniform_batch_size_capture_list = [
size for size in uniform_batch_size_capture_list
if size <= max_num_decode_tokens
]

self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
batch_size_capture_list, uniform_batch_size_capture_list,
uniform_decode_len)

def recalculate_max_model_len(self, max_model_len: int):
# Can only be called in try_verify_and_update_config
Expand Down
46 changes: 44 additions & 2 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,16 @@ class CompilationConfig:

max_capture_size: int = field(default=None, init=False) # type: ignore
"""not configurable, computed after init"""
uniform_cudagraph_capture_sizes: Optional[list[int]] = None
"""
List for capture sizes for uniform decode for the main model. Its elements
should be multiples of uniform_decode_len(1 for common pure decode, or
1+num_speculative_tokens for speculative decode).
Not configurable, computed after init
"""
max_uniform_capture_size: int = field(default=None,
init=False) # type: ignore
"""not configurable, computed after init"""
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
bs_to_padded_graph_size: list[int] = field(
Expand All @@ -313,6 +323,10 @@ class CompilationConfig:
Intuitively, bs_to_padded_graph_size should be dict[int, int].
since we know all keys are in a range [0, max_capture_size],
we can optimize it to list[int] for better lookup performance."""
bs_to_padded_graph_size_uniform: list[int] = field(
default=None, # type: ignore
init=False)
"""same as bs_to_padded_graph_size, but for uniform capture sizes"""

# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = field(default_factory=Counter,
Expand Down Expand Up @@ -373,6 +387,7 @@ def __repr__(self) -> str:
"disabled_custom_ops": True,
"compilation_time": True,
"bs_to_padded_graph_size": True,
"bs_to_padded_graph_size_uniform": True,
"traced_files": True,
"inductor_compile_config": {
"post_grad_custom_post_pass": True,
Expand Down Expand Up @@ -482,8 +497,9 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
from vllm.compilation.backends import VllmBackend
return VllmBackend(vllm_config)

def init_with_cudagraph_sizes(self,
cudagraph_capture_sizes: list[int]) -> None:
def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int],
uniform_cudagraph_capture_sizes: list[int],
uniform_decode_len: int) -> None:
"""To complete the initialization of config,
we need to know the cudagraph sizes."""

Expand All @@ -497,6 +513,12 @@ def init_with_cudagraph_sizes(self,
" %s is overridden by config %s"),
cudagraph_capture_sizes, dedup_sizes)
self.cudagraph_capture_sizes = dedup_sizes
if envs.VLLM_USE_V1:
# recompute uniform_cudagraph_capture_sizes based on the
# dedup_sizes(computed from config) and uniform_decode_len
uniform_cudagraph_capture_sizes = sorted(
set((size // uniform_decode_len) * uniform_decode_len
for size in dedup_sizes if size >= uniform_decode_len))

computed_compile_sizes = []
if self.compile_sizes is not None:
Expand All @@ -518,6 +540,11 @@ def init_with_cudagraph_sizes(self,
self.max_capture_size = self.cudagraph_capture_sizes[
0] if self.cudagraph_capture_sizes else 0

self.uniform_cudagraph_capture_sizes = sorted(
uniform_cudagraph_capture_sizes, reverse=True)
self.max_uniform_capture_size = self.uniform_cudagraph_capture_sizes[
0] if self.uniform_cudagraph_capture_sizes else 0

# pre-compute the mapping from batch size to padded graph size
self.bs_to_padded_graph_size = [
0 for i in range(self.max_capture_size + 1)
Expand All @@ -532,6 +559,21 @@ def init_with_cudagraph_sizes(self,
self.bs_to_padded_graph_size[
self.max_capture_size] = self.max_capture_size

# pre-compute the mapping for uniform decode padding.
self.bs_to_padded_graph_size_uniform = [
0 for i in range(self.max_uniform_capture_size + 1)
]

for end, start in zip(self.uniform_cudagraph_capture_sizes,
self.uniform_cudagraph_capture_sizes[1:] + [0]):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size_uniform[bs] = start
else:
self.bs_to_padded_graph_size_uniform[bs] = end
self.bs_to_padded_graph_size_uniform[self.max_uniform_capture_size] =\
self.max_uniform_capture_size

def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called only when level is
# CompilationLevel.PIECEWISE
Expand Down
9 changes: 8 additions & 1 deletion vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,20 @@ class BatchDescriptor(NamedTuple):
False can also be used for an uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
"""
uniform_query_len: int = 0
"""
For non-uniform batches, should set to 0 for uniquely identifying the batch.
For uniform batches, it is the max_query_len of a uniform batch.
"""

@property
def non_uniform(self) -> "BatchDescriptor":
"""
Return a non-uniform version of current batch descriptor.
"""
return BatchDescriptor(self.num_tokens, uniform_decode=False)
return BatchDescriptor(self.num_tokens,
uniform_decode=False,
uniform_query_len=0)


def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
Expand Down
16 changes: 0 additions & 16 deletions vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,22 +539,6 @@ def build(self,
)
return attn_metadata

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with FlashInfer.
"""
m = common_attn_metadata

assert m.num_reqs == m.num_actual_tokens, \
"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."

m.max_query_len = 1 # decode-only

return self.build(0, m)

def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
Expand Down
19 changes: 1 addition & 18 deletions vllm/v1/attention/backends/mamba_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
AttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec

M = TypeVar("M")
Expand Down Expand Up @@ -37,19 +36,3 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
dtype=torch.int32,
device=device,
)

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata

assert m.num_reqs == m.num_actual_tokens, \
"Mamba only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."

m.max_query_len = 1 # decode-only

return self.build(0, m)
16 changes: 0 additions & 16 deletions vllm/v1/attention/backends/mla/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,22 +621,6 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens=seq_lens_device,
)

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs <= (m.num_actual_tokens *
self.reorder_batch_threshold), \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."

assert m.max_query_len <= self.reorder_batch_threshold # decode only

return self.build(0, m)

def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
Expand Down
9 changes: 0 additions & 9 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
self.aot_sliding_window: Optional[tuple[int, int]] = None
self.total_tokens: int = 0

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
self.total_tokens = self.model_config.max_model_len \
* self.vllm_config.scheduler_config.max_num_partial_prefills
res = self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
self.total_tokens = 0
return res

def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
Expand Down
Loading