Skip to content

Commit 64e39d6

Browse files
[BugFix] Temporary fix for IMA with MTP = 2 and full-cg (#28315)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 1b82fb0 commit 64e39d6

File tree

2 files changed

+80
-13
lines changed

2 files changed

+80
-13
lines changed

vllm/config/compilation.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from vllm.logger import init_logger
1919
from vllm.platforms import current_platform
2020
from vllm.utils.import_utils import resolve_obj_by_qualname
21+
from vllm.utils.math_utils import round_up
2122
from vllm.utils.torch_utils import is_torch_equal_or_newer
2223

2324
if TYPE_CHECKING:
@@ -773,19 +774,8 @@ def post_init_cudagraph_sizes(self) -> None:
773774
if self.cudagraph_capture_sizes:
774775
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
775776

776-
# pre-compute the mapping from batch size to padded graph size
777-
self.bs_to_padded_graph_size = [
778-
0 for i in range(self.max_cudagraph_capture_size + 1)
779-
]
780-
for end, start in zip(
781-
self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
782-
[0] + self.cudagraph_capture_sizes,
783-
):
784-
for bs in range(start, end):
785-
if bs == start:
786-
self.bs_to_padded_graph_size[bs] = start
787-
else:
788-
self.bs_to_padded_graph_size[bs] = end
777+
# May get recomputed in the model runner if adjustment is needed for spec-decode
778+
self.compute_bs_to_padded_graph_size()
789779

790780
def set_splitting_ops_for_v1(self):
791781
# NOTE: this function needs to be called only when mode is
@@ -922,3 +912,64 @@ def custom_op_log_check(self):
922912
enable_str,
923913
op,
924914
)
915+
916+
def adjust_cudagraph_sizes_for_spec_decode(
917+
self, uniform_decode_query_len: int, tensor_parallel_size: int
918+
):
919+
multiple_of = uniform_decode_query_len
920+
if tensor_parallel_size > 1:
921+
multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
922+
if (
923+
multiple_of % uniform_decode_query_len != 0
924+
or multiple_of % tensor_parallel_size != 0
925+
):
926+
raise ValueError(
927+
f"Can't determine cudagraph shapes that are both a "
928+
f"multiple of {uniform_decode_query_len} "
929+
f"(num_speculative_tokens + 1) required by spec-decode "
930+
f"and {tensor_parallel_size} (tensor_parallel_size) "
931+
f"required by sequence parallelism please adjust "
932+
f"num_speculative_tokens or disable sequence parallelism"
933+
)
934+
935+
if not self.cudagraph_capture_sizes or multiple_of <= 1:
936+
return
937+
938+
assert self.max_cudagraph_capture_size is not None
939+
rounded_sizes = sorted(
940+
set(
941+
round_up(size, multiple_of)
942+
for size in self.cudagraph_capture_sizes
943+
if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
944+
)
945+
)
946+
947+
if len(rounded_sizes) == 0:
948+
logger.warning(
949+
"No valid cudagraph sizes after rounding to multiple of "
950+
" num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
951+
" or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
952+
multiple_of,
953+
)
954+
return
955+
956+
self.max_cudagraph_capture_size = rounded_sizes[-1]
957+
self.cudagraph_capture_sizes = rounded_sizes
958+
959+
# Recompute after adjusting the cudagraph sizes
960+
self.compute_bs_to_padded_graph_size()
961+
962+
def compute_bs_to_padded_graph_size(self):
963+
# pre-compute the mapping from batch size to padded graph size
964+
self.bs_to_padded_graph_size = [
965+
0 for i in range(self.max_cudagraph_capture_size + 1)
966+
]
967+
for end, start in zip(
968+
self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
969+
[0] + self.cudagraph_capture_sizes,
970+
):
971+
for bs in range(start, end):
972+
if bs == start:
973+
self.bs_to_padded_graph_size[bs] = start
974+
else:
975+
self.bs_to_padded_graph_size[bs] = end

vllm/v1/worker/gpu_model_runner.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4332,6 +4332,22 @@ def _check_and_update_cudagraph_mode(
43324332
"and make sure compilation mode is VLLM_COMPILE"
43334333
)
43344334

4335+
# if we have dedicated decode cudagraphs, and spec-decode is enabled,
4336+
# we need to adjust the cudagraph sizes to be a multiple of the uniform
4337+
# decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
4338+
# temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
4339+
# Will be removed in the near future when we have seperate cudagraph capture
4340+
# sizes for decode and mixed prefill-decode.
4341+
if (
4342+
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
4343+
and cudagraph_mode.separate_routine()
4344+
and self.uniform_decode_query_len > 1
4345+
):
4346+
self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
4347+
self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
4348+
)
4349+
self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes
4350+
43354351
# Trigger cudagraph dispatching keys initialization after
43364352
# resolved cudagraph mode.
43374353
self.cudagraph_dispatcher.initialize_cudagraph_keys(

0 commit comments

Comments
 (0)