Commit 4ae8b23

LucasWilkinson authored and khluu committed
[BugFix] Temporary fix for IMA with MTP = 2 and full-cg (#28315)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
(cherry picked from commit 64e39d6)

1 parent 48239bb commit 4ae8b23
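
To make the fix concrete: under MTP with two speculative tokens, each uniform decode step runs num_speculative_tokens + 1 = 3 query tokens per request, so full decode cudagraphs must be captured at batch sizes that are multiples of 3. Below is a hedged sketch of the rounding this commit applies; the capture sizes are illustrative, not taken from a real vLLM config.

```python
def round_up(x: int, multiple: int) -> int:
    # Same contract as vllm.utils.math_utils.round_up: smallest
    # multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple

# Illustrative values: MTP = 2 speculative tokens per step.
uniform_decode_query_len = 2 + 1      # num_speculative_tokens + 1
capture_sizes = [1, 2, 4, 8, 16, 32]  # hypothetical capture list
max_capture_size = 32

rounded = sorted(
    {
        round_up(s, uniform_decode_query_len)
        for s in capture_sizes
        if round_up(s, uniform_decode_query_len) <= max_capture_size
    }
)
print(rounded)  # [3, 6, 9, 18]
```

Note how 32 rounds up to 33, which exceeds the illustrative max and is dropped, so max_cudagraph_capture_size shrinks to the largest surviving size (18 here).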

File tree: 2 files changed, +80 −13 lines changed

  vllm/config/compilation.py
  vllm/v1/worker/gpu_model_runner.py

vllm/config/compilation.py

Lines changed: 64 additions & 13 deletions
```diff
@@ -18,6 +18,7 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.math_utils import round_up
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
 if TYPE_CHECKING:
@@ -752,19 +753,8 @@ def post_init_cudagraph_sizes(self) -> None:
         if self.cudagraph_capture_sizes:
             assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
 
-        # pre-compute the mapping from batch size to padded graph size
-        self.bs_to_padded_graph_size = [
-            0 for i in range(self.max_cudagraph_capture_size + 1)
-        ]
-        for end, start in zip(
-            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
-            [0] + self.cudagraph_capture_sizes,
-        ):
-            for bs in range(start, end):
-                if bs == start:
-                    self.bs_to_padded_graph_size[bs] = start
-                else:
-                    self.bs_to_padded_graph_size[bs] = end
+        # May get recomputed in the model runner if adjustment is needed for spec-decode
+        self.compute_bs_to_padded_graph_size()
 
     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called only when mode is
@@ -901,3 +891,64 @@ def custom_op_log_check(self):
                 enable_str,
                 op,
             )
+
+    def adjust_cudagraph_sizes_for_spec_decode(
+        self, uniform_decode_query_len: int, tensor_parallel_size: int
+    ):
+        multiple_of = uniform_decode_query_len
+        if tensor_parallel_size > 1:
+            multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
+            if (
+                multiple_of % uniform_decode_query_len != 0
+                or multiple_of % tensor_parallel_size != 0
+            ):
+                raise ValueError(
+                    f"Can't determine cudagraph shapes that are both a "
+                    f"multiple of {uniform_decode_query_len} "
+                    f"(num_speculative_tokens + 1) required by spec-decode "
+                    f"and {tensor_parallel_size} (tensor_parallel_size) "
+                    f"required by sequence parallelism; please adjust "
+                    f"num_speculative_tokens or disable sequence parallelism"
+                )
+
+        if not self.cudagraph_capture_sizes or multiple_of <= 1:
+            return
+
+        assert self.max_cudagraph_capture_size is not None
+        rounded_sizes = sorted(
+            set(
+                round_up(size, multiple_of)
+                for size in self.cudagraph_capture_sizes
+                if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
+            )
+        )
+
+        if len(rounded_sizes) == 0:
+            logger.warning(
+                "No valid cudagraph sizes after rounding to multiple of "
+                "num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
+                " or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
+                multiple_of,
+            )
+            return
+
+        self.max_cudagraph_capture_size = rounded_sizes[-1]
+        self.cudagraph_capture_sizes = rounded_sizes
+
+        # Recompute after adjusting the cudagraph sizes
+        self.compute_bs_to_padded_graph_size()
+
+    def compute_bs_to_padded_graph_size(self):
+        # pre-compute the mapping from batch size to padded graph size
+        self.bs_to_padded_graph_size = [
+            0 for i in range(self.max_cudagraph_capture_size + 1)
+        ]
+        for end, start in zip(
+            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
+            [0] + self.cudagraph_capture_sizes,
+        ):
+            for bs in range(start, end):
+                if bs == start:
+                    self.bs_to_padded_graph_size[bs] = start
+                else:
+                    self.bs_to_padded_graph_size[bs] = end
```
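
For readers tracing compute_bs_to_padded_graph_size, here is a self-contained sketch of the padding map it builds, again with illustrative sizes rather than values from a real config:

```python
# Standalone sketch of the bs -> padded-graph-size mapping above.
# `capture_sizes` and `max_size` are illustrative, not a real config.
capture_sizes = [3, 6, 9, 18]  # e.g. sizes after rounding to multiples of 3
max_size = 18

bs_to_padded = [0] * (max_size + 1)
for end, start in zip(capture_sizes + [max_size + 1], [0] + capture_sizes):
    for bs in range(start, end):
        # A batch size equal to a capture size maps to itself;
        # anything in between pads up to the next capture size.
        bs_to_padded[bs] = start if bs == start else end

print(bs_to_padded[1], bs_to_padded[3], bs_to_padded[4], bs_to_padded[18])
# -> 3 3 6 18
```

Each batch size maps to the smallest captured size that can hold it, which is what lets the runner replay one captured graph for any smaller batch.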

vllm/v1/worker/gpu_model_runner.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -4315,6 +4315,22 @@ def _check_and_update_cudagraph_mode(
                 "and make sure compilation mode is VLLM_COMPILE"
             )
 
+        # if we have dedicated decode cudagraphs, and spec-decode is enabled,
+        # we need to adjust the cudagraph sizes to be a multiple of the uniform
+        # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
+        # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
+        # Will be removed in the near future when we have separate cudagraph capture
+        # sizes for decode and mixed prefill-decode.
+        if (
+            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
+            and cudagraph_mode.separate_routine()
+            and self.uniform_decode_query_len > 1
+        ):
+            self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
+                self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
+            )
+            self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes
+
         # Trigger cudagraph dispatching keys initialization after
         # resolved cudagraph mode.
         self.cudagraph_dispatcher.initialize_cudagraph_keys(
```
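
The runner passes tensor_parallel_size so that sequence parallelism keeps graph shapes divisible by the TP degree as well as by the decode query length. A hedged sketch of how the rounding multiple is chosen, mirroring adjust_cudagraph_sizes_for_spec_decode above with illustrative values:

```python
# Sketch of the multiple_of selection, including the incompatibility check.
def rounding_multiple(uniform_decode_query_len: int, tensor_parallel_size: int) -> int:
    multiple_of = uniform_decode_query_len
    if tensor_parallel_size > 1:
        multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
        if (
            multiple_of % uniform_decode_query_len != 0
            or multiple_of % tensor_parallel_size != 0
        ):
            # No single size satisfies both constraints, so the config
            # is rejected instead of capturing unusable graph shapes.
            raise ValueError("no common cudagraph shape for these settings")
    return multiple_of

print(rounding_multiple(3, 1))  # 3: multiples of the decode query length only
print(rounding_multiple(3, 6))  # 6: max(3, 6) is divisible by both 3 and 6
# rounding_multiple(3, 4) would raise: max(3, 4) = 4 is not a multiple of 3.
```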
