From 6e9926d6b6d66c8a43bf96515ebaea1bc5c6f909 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Fri, 7 Nov 2025 18:33:52 +0000
Subject: [PATCH 1/9] [WIP] Tmp fix for IMA with MTP = 2 and full-cg

Temp fix for https://github.com/vllm-project/vllm/issues/28207

Signed-off-by: Lucas Wilkinson
---
 vllm/config/compilation.py         | 29 +++++++++++++++--------------
 vllm/config/vllm.py                | 13 +++++++++++++
 vllm/v1/worker/gpu_model_runner.py | 14 ++++++++++++++
 3 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index e1d60ee84d89..2e7077ebd18d 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -822,20 +822,6 @@ def post_init_cudagraph_sizes(self) -> None:
         if self.cudagraph_capture_sizes:
             assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size

-        # pre-compute the mapping from batch size to padded graph size
-        self.bs_to_padded_graph_size = [
-            0 for i in range(self.max_cudagraph_capture_size + 1)
-        ]
-        for end, start in zip(
-            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
-            [0] + self.cudagraph_capture_sizes,
-        ):
-            for bs in range(start, end):
-                if bs == start:
-                    self.bs_to_padded_graph_size[bs] = start
-                else:
-                    self.bs_to_padded_graph_size[bs] = end
-
     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called only when mode is
         # CompilationMode.VLLM_COMPILE
@@ -972,3 +958,18 @@ def custom_op_log_check(self):
                 enable_str,
                 op,
             )
+
+    def compute_bs_to_padded_graph_size(self):
+        # pre-compute the mapping from batch size to padded graph size
+        self.bs_to_padded_graph_size = [
+            0 for i in range(self.max_cudagraph_capture_size + 1)
+        ]
+        for end, start in zip(
+            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
+            [0] + self.cudagraph_capture_sizes,
+        ):
+            for bs in range(start, end):
+                if bs == start:
+                    self.bs_to_padded_graph_size[bs] = start
+                else:
+                    self.bs_to_padded_graph_size[bs] = end
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 60458b26944a..3dd553b0bee7 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -24,6 +24,7 @@
 from vllm.logger import enable_trace_function_call, init_logger
 from vllm.transformers_utils.runai_utils import is_runai_obj_uri
 from vllm.utils import random_uuid
+from vllm.utils.math_utils import round_up

 from .cache import CacheConfig
 from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
@@ -215,6 +216,18 @@ def pad_for_cudagraph(self, batch_size: int) -> int:
         # i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
         return self.compilation_config.bs_to_padded_graph_size[batch_size]

+    def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
+        new_sizes = sorted(
+            [
+                round_up(size, multiple_of)
+                for size in self.compilation_config.cudagraph_capture_sizes
+            ]
+        )
+        if new_sizes[-1] > self.compilation_config.max_cudagraph_capture_size:
+            new_sizes = new_sizes[:-1]
+        self.compilation_config.max_cudagraph_capture_size = new_sizes[-1]
+        self.compilation_config.cudagraph_capture_sizes = new_sizes
+
     def enable_trace_function_call_for_thread(self) -> None:
         """
         Set up function tracing for the current thread,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index b14b6b1c3f52..42cbb6fc0e2b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4321,6 +4321,20 @@ def _check_and_update_cudagraph_mode(
                 "and make sure compilation mode is VLLM_COMPILE"
             )

+        # if we have dedicated decode cudagraphs, and spec-decode is enabled,
+        # we need to adjust the cudagraph sizes to be a multiple of the uniform
+        # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
+        # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
+        if (
+            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
+            and cudagraph_mode.separate_routine()
+            and self.uniform_decode_query_len > 1
+        ):
+            self.vllm_config.adjust_cudagraph_sizes_to_be_multipe_of(
+                self.uniform_decode_query_len
+            )
+            self.vllm_config.compilation_config.compute_bs_to_padded_graph_size()
+
         # Trigger cudagraph dispatching keys initialization after
         # resolved cudagraph mode.
         self.cudagraph_dispatcher.initialize_cudagraph_keys(
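For reference, here is a minimal standalone sketch of the padding table that this
patch moves into compute_bs_to_padded_graph_size(); the free function and its
arguments are illustrative stand-ins, not vLLM APIs:

    def bs_to_padded_graph_size(capture_sizes: list[int], max_size: int) -> list[int]:
        # Map every batch size 0..max_size to the smallest captured size >= it.
        padded = [0] * (max_size + 1)
        for end, start in zip(capture_sizes + [max_size + 1], [0] + capture_sizes):
            for bs in range(start, end):
                padded[bs] = start if bs == start else end
        return padded

    # With capture sizes [1, 2, 4, 8], a batch of 3 is padded up to 4:
    print(bs_to_padded_graph_size([1, 2, 4, 8], 8))  # [0, 1, 2, 4, 4, 8, 8, 8, 8]

With MTP = 2, each uniform-decode request contributes num_speculative_tokens + 1 = 3
tokens, so padding a one-request decode batch from 3 tokens to the captured size 4
breaks the uniform-decode shape assumption of the full cudagraph; rounding the capture
sizes up to multiples of 3, as this patch series does, avoids that.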
From 00c4e138028b95c57af13897afedccaad7bb20d2 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Fri, 7 Nov 2025 18:38:56 +0000
Subject: [PATCH 2/9] cleanup

Signed-off-by: Lucas Wilkinson
---
 vllm/config/compilation.py         | 13 +++++++++++++
 vllm/config/vllm.py                | 13 -------------
 vllm/v1/worker/gpu_model_runner.py |  4 ++--
 3 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 2e7077ebd18d..7c6214b8f251 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -18,6 +18,7 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.math_utils import round_up
 from vllm.utils.torch_utils import is_torch_equal_or_newer

 if TYPE_CHECKING:
@@ -959,6 +960,18 @@ def custom_op_log_check(self):
                 op,
             )

+    def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
+        new_sizes = sorted(
+            [
+                round_up(size, multiple_of)
+                for size in self.compilation_config.cudagraph_capture_sizes
+            ]
+        )
+        if new_sizes[-1] > self.compilation_config.max_cudagraph_capture_size:
+            new_sizes = new_sizes[:-1]
+        self.compilation_config.max_cudagraph_capture_size = new_sizes[-1]
+        self.compilation_config.cudagraph_capture_sizes = new_sizes
+
     def compute_bs_to_padded_graph_size(self):
         # pre-compute the mapping from batch size to padded graph size
         self.bs_to_padded_graph_size = [
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 3dd553b0bee7..60458b26944a 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -24,7 +24,6 @@
 from vllm.logger import enable_trace_function_call, init_logger
 from vllm.transformers_utils.runai_utils import is_runai_obj_uri
 from vllm.utils import random_uuid
-from vllm.utils.math_utils import round_up

 from .cache import CacheConfig
 from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
@@ -216,18 +215,6 @@ def pad_for_cudagraph(self, batch_size: int) -> int:
         # i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
         return self.compilation_config.bs_to_padded_graph_size[batch_size]

-    def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
-        new_sizes = sorted(
-            [
-                round_up(size, multiple_of)
-                for size in self.compilation_config.cudagraph_capture_sizes
-            ]
-        )
-        if new_sizes[-1] > self.compilation_config.max_cudagraph_capture_size:
-            new_sizes = new_sizes[:-1]
-        self.compilation_config.max_cudagraph_capture_size = new_sizes[-1]
-        self.compilation_config.cudagraph_capture_sizes = new_sizes
-
     def enable_trace_function_call_for_thread(self) -> None:
         """
         Set up function tracing for the current thread,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 42cbb6fc0e2b..75690f13ac17 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4330,10 +4330,10 @@ def _check_and_update_cudagraph_mode(
             and cudagraph_mode.separate_routine()
             and self.uniform_decode_query_len > 1
         ):
-            self.vllm_config.adjust_cudagraph_sizes_to_be_multipe_of(
+            self.compilation_config.adjust_cudagraph_sizes_to_be_multipe_of(
                 self.uniform_decode_query_len
             )
-            self.vllm_config.compilation_config.compute_bs_to_padded_graph_size()
+            self.compilation_config.compute_bs_to_padded_graph_size()

         # Trigger cudagraph dispatching keys initialization after
         # resolved cudagraph mode.

From 5d67c78e0d95a53e6a536fe897bbcb50250a599e Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 11 Nov 2025 18:12:29 -0800
Subject: [PATCH 3/9] wip

Signed-off-by: Lucas Wilkinson
---
 vllm/config/compilation.py | 28 +++++++++++++++++++---------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 7c6214b8f251..ef62232143be 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -961,16 +961,26 @@ def custom_op_log_check(self):
             )

     def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
-        new_sizes = sorted(
-            [
-                round_up(size, multiple_of)
-                for size in self.compilation_config.cudagraph_capture_sizes
-            ]
+        if not self.cudagraph_capture_sizes:
+            return
+
+        rounded_sizes = sorted(
+            round_up(size, multiple_of)
+            for size in self.cudagraph_capture_sizes
+            if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
         )
+
+        if len(rounded_sizes) == 0:
+            logger.warning(
+                "No valid cudagraph sizes after rounding to multiple of "
+                "num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens "
+                "or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
+                multiple_of,
+            )
+            return
+
+        self.max_cudagraph_capture_size = rounded_sizes[-1]
+        self.cudagraph_capture_sizes = rounded_sizes
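To illustrate the rounding in patch 3, a self-contained sketch (round_up mirrors
vllm.utils.math_utils.round_up; the adjust() helper is a hypothetical stand-in for
the method above):

    def round_up(x: int, multiple: int) -> int:
        # Round x up to the nearest multiple.
        return ((x + multiple - 1) // multiple) * multiple

    def adjust(capture_sizes: list[int], max_size: int, multiple_of: int) -> list[int]:
        # Round every capture size up; drop any result that would exceed max_size.
        return sorted(
            round_up(size, multiple_of)
            for size in capture_sizes
            if round_up(size, multiple_of) <= max_size
        )

    # With MTP = 2 (multiple_of = 3): 16 would round to 18 > 16 and is dropped.
    print(adjust([1, 2, 4, 8, 16], 16, 3))  # [3, 3, 6, 9]

Note the duplicate 3 in the result; patch 4 below wraps the generator in set() to
deduplicate, and also skips the whole adjustment when multiple_of <= 1.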
From 60bb90138d7613ef6b5b3f424412bea3b99fc48f Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 11 Nov 2025 20:39:02 -0800
Subject: [PATCH 4/9] wip

Signed-off-by: Lucas Wilkinson
---
 vllm/config/compilation.py         | 12 ++++++++----
 vllm/v1/worker/gpu_model_runner.py |  4 ++++
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index ef62232143be..03de5129efa9 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -961,13 +961,17 @@ def custom_op_log_check(self):
             )

     def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
-        if not self.cudagraph_capture_sizes:
+        if not self.cudagraph_capture_sizes or multiple_of <= 1:
             return

+        assert self.max_cudagraph_capture_size is not None
+
         rounded_sizes = sorted(
-            round_up(size, multiple_of)
-            for size in self.cudagraph_capture_sizes
-            if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
+            set(
+                round_up(size, multiple_of)
+                for size in self.cudagraph_capture_sizes
+                if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
+            )
         )

         if len(rounded_sizes) == 0:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 75690f13ac17..d8ed89f4de60 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4325,6 +4325,8 @@ def _check_and_update_cudagraph_mode(
         # we need to adjust the cudagraph sizes to be a multiple of the uniform
         # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
         # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
+        # Will be removed in the near future when we have separate cudagraph capture
+        # sizes for decode and mixed prefill-decode.
         if (
             cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
             and cudagraph_mode.separate_routine()
@@ -4333,6 +4335,8 @@ def _check_and_update_cudagraph_mode(
             self.compilation_config.adjust_cudagraph_sizes_to_be_multipe_of(
                 self.uniform_decode_query_len
             )
+            self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes
+
             self.compilation_config.compute_bs_to_padded_graph_size()

         # Trigger cudagraph dispatching keys initialization after
From 529078ed9c847b174d111e748f0ae41a52e6fa0d Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 11 Nov 2025 20:41:07 -0800
Subject: [PATCH 5/9] cleanup

Signed-off-by: Lucas Wilkinson
---
 vllm/config/compilation.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 03de5129efa9..da7a50c8a4b1 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -965,7 +965,6 @@ def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
             return

         assert self.max_cudagraph_capture_size is not None
-
         rounded_sizes = sorted(
             set(
                 round_up(size, multiple_of)

From bbf54737ec6147c90f2ab42494f9647ea65d7cfe Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Fri, 14 Nov 2025 16:45:01 -0800
Subject: [PATCH 6/9] fix tests

Signed-off-by: Lucas Wilkinson
---
 vllm/config/compilation.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index da7a50c8a4b1..d3950f7715e6 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -755,6 +755,9 @@ def __post_init__(self) -> None:
         if self.backend == "":
             self.backend = current_platform.simple_compile_backend

+        # Gets recomputed in the model runner but compute it here for testing.
+        self.post_init_cudagraph_sizes()
+
     def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
         """
         Initialize the backend for the compilation config from a vllm config.

From 3f3848d276d3b32613cdf41ea79e949e69b78239 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Sat, 15 Nov 2025 11:09:57 -0800
Subject: [PATCH 7/9] fix

Signed-off-by: Lucas Wilkinson
---
 vllm/config/compilation.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index f6c00cc725a5..4d5bcb5244a1 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -706,9 +706,6 @@ def __post_init__(self) -> None:
         if self.backend == "":
             self.backend = current_platform.simple_compile_backend

-        # Gets recomputed in the model runner but compute it here for testing.
-        self.post_init_cudagraph_sizes()
-
     def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
         """
         Initialize the backend for the compilation config from a vllm config.
@@ -777,6 +774,9 @@ def post_init_cudagraph_sizes(self) -> None:
         if self.cudagraph_capture_sizes:
             assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size

+        # Gets recomputed in the model runner but compute it here for testing.
+        self.compute_bs_to_padded_graph_size()
+
     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called only when mode is
         # CompilationMode.VLLM_COMPILE

From a1ecdc6531e04a40c08cb00c8f259647708eb126 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Sat, 15 Nov 2025 13:27:45 -0800
Subject: [PATCH 8/9] fix tests

Signed-off-by: Lucas Wilkinson
---
 vllm/config/compilation.py         | 2 +-
 vllm/v1/worker/gpu_model_runner.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 4d5bcb5244a1..1a6a0b7b5c18 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -774,7 +774,7 @@ def post_init_cudagraph_sizes(self) -> None:
         if self.cudagraph_capture_sizes:
             assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size

-        # Gets recomputed in the model runner but compute it here for testing.
+        # May get recomputed in the model runner if adjustment is needed for spec-decode
         self.compute_bs_to_padded_graph_size()

     def set_splitting_ops_for_v1(self):
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a4696bfecfe7..4dda7e952ad5 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4357,8 +4357,8 @@ def _check_and_update_cudagraph_mode(
                 self.uniform_decode_query_len
             )
             self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes
-
-            self.compilation_config.compute_bs_to_padded_graph_size()
+            # Recompute after adjusting the cudagraph sizes
+            self.compilation_config.compute_bs_to_padded_graph_size()

         # Trigger cudagraph dispatching keys initialization after
         # resolved cudagraph mode.
From ea2e6498e6a512eaedc1395d0da6c6f964d76d6e Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Sun, 16 Nov 2025 13:03:22 -0800
Subject: [PATCH 9/9] add warning

Signed-off-by: Lucas Wilkinson
---
 vllm/config/compilation.py         | 23 ++++++++++++++++++++++-
 vllm/v1/worker/gpu_model_runner.py |  6 ++----
 2 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 1a6a0b7b5c18..088d0b1af757 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -913,7 +913,25 @@ def custom_op_log_check(self):
                 op,
             )

-    def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
+    def adjust_cudagraph_sizes_for_spec_decode(
+        self, uniform_decode_query_len: int, tensor_parallel_size: int
+    ):
+        multiple_of = uniform_decode_query_len
+        if tensor_parallel_size > 1:
+            multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
+            if (
+                multiple_of % uniform_decode_query_len != 0
+                or multiple_of % tensor_parallel_size != 0
+            ):
+                raise ValueError(
+                    f"Can't determine cudagraph shapes that are both a "
+                    f"multiple of {uniform_decode_query_len} "
+                    f"(num_speculative_tokens + 1) required by spec-decode "
+                    f"and {tensor_parallel_size} (tensor_parallel_size) "
+                    f"required by sequence parallelism; please adjust "
+                    f"num_speculative_tokens or disable sequence parallelism"
+                )
+
         if not self.cudagraph_capture_sizes or multiple_of <= 1:
             return

@@ -938,6 +956,9 @@ def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
         self.max_cudagraph_capture_size = rounded_sizes[-1]
         self.cudagraph_capture_sizes = rounded_sizes

+        # Recompute after adjusting the cudagraph sizes
+        self.compute_bs_to_padded_graph_size()
+
     def compute_bs_to_padded_graph_size(self):
         # pre-compute the mapping from batch size to padded graph size
         self.bs_to_padded_graph_size = [
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4dda7e952ad5..a87c40e3421d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4353,12 +4353,10 @@ def _check_and_update_cudagraph_mode(
             and cudagraph_mode.separate_routine()
             and self.uniform_decode_query_len > 1
         ):
-            self.compilation_config.adjust_cudagraph_sizes_to_be_multipe_of(
-                self.uniform_decode_query_len
+            self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
+                self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
             )
             self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes
-            # Recompute after adjusting the cudagraph sizes
-            self.compilation_config.compute_bs_to_padded_graph_size()

         # Trigger cudagraph dispatching keys initialization after
         # resolved cudagraph mode.
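To illustrate the constraint introduced in patch 9, a standalone sketch of how the
common multiple is chosen (the free function name and the bare ValueError message
are assumed for illustration):

    def combined_multiple(uniform_decode_query_len: int, tensor_parallel_size: int) -> int:
        multiple_of = uniform_decode_query_len
        if tensor_parallel_size > 1:
            # Sequence parallelism also needs sizes divisible by the TP degree.
            multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
            if (
                multiple_of % uniform_decode_query_len != 0
                or multiple_of % tensor_parallel_size != 0
            ):
                raise ValueError("no cudagraph shape satisfies both constraints")
        return multiple_of

    print(combined_multiple(3, 1))  # 3: spec-decode constraint only
    print(combined_multiple(2, 4))  # 4: divisible by both 2 and 4
    # combined_multiple(3, 4) raises: max(3, 4) = 4 is not a multiple of 3.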