Skip to content

Commit 60bb901

Browse files
wip
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 5d67c78 commit 60bb901

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

vllm/config/compilation.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -961,13 +961,17 @@ def custom_op_log_check(self):
961961
)
962962

963963
def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
964-
if not self.cudagraph_capture_sizes:
964+
if not self.cudagraph_capture_sizes or multiple_of <= 1:
965965
return
966966

967+
assert self.max_cudagraph_capture_size is not None
968+
967969
rounded_sizes = sorted(
968-
round_up(size, multiple_of)
969-
for size in self.cudagraph_capture_sizes
970-
if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
970+
set(
971+
round_up(size, multiple_of)
972+
for size in self.cudagraph_capture_sizes
973+
if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
974+
)
971975
)
972976

973977
if len(rounded_sizes) == 0:

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4325,6 +4325,8 @@ def _check_and_update_cudagraph_mode(
43254325
# we need to adjust the cudagraph sizes to be a multiple of the uniform
43264326
# decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
43274327
# temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
4328+
# Will be removed in the near future when we have seperate cudagraph capture
4329+
# sizes for decode and mixed prefill-decode.
43284330
if (
43294331
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
43304332
and cudagraph_mode.separate_routine()
@@ -4333,6 +4335,8 @@ def _check_and_update_cudagraph_mode(
43334335
self.compilation_config.adjust_cudagraph_sizes_to_be_multipe_of(
43344336
self.uniform_decode_query_len
43354337
)
4338+
self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes
4339+
43364340
self.compilation_config.compute_bs_to_padded_graph_size()
43374341

43384342
# Trigger cudagraph dispatching keys initialization after

0 commit comments

Comments
 (0)