 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.math_utils import round_up
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
 if TYPE_CHECKING:
@@ -752,19 +753,8 @@ def post_init_cudagraph_sizes(self) -> None:
         if self.cudagraph_capture_sizes:
             assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
 
-            # pre-compute the mapping from batch size to padded graph size
-            self.bs_to_padded_graph_size = [
-                0 for i in range(self.max_cudagraph_capture_size + 1)
-            ]
-            for end, start in zip(
-                self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
-                [0] + self.cudagraph_capture_sizes,
-            ):
-                for bs in range(start, end):
-                    if bs == start:
-                        self.bs_to_padded_graph_size[bs] = start
-                    else:
-                        self.bs_to_padded_graph_size[bs] = end
+            # May get recomputed in the model runner if adjustment is needed for spec-decode
+            self.compute_bs_to_padded_graph_size()
 
     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called only when mode is
@@ -901,3 +891,64 @@ def custom_op_log_check(self):
                 enable_str,
                 op,
             )
+
+    def adjust_cudagraph_sizes_for_spec_decode(
+        self, uniform_decode_query_len: int, tensor_parallel_size: int
+    ):
+        multiple_of = uniform_decode_query_len
+        if tensor_parallel_size > 1:
+            multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
+            if (
+                multiple_of % uniform_decode_query_len != 0
+                or multiple_of % tensor_parallel_size != 0
+            ):
+                raise ValueError(
+                    f"Can't determine cudagraph shapes that are both a "
+                    f"multiple of {uniform_decode_query_len} "
+                    f"(num_speculative_tokens + 1) required by spec-decode "
+                    f"and {tensor_parallel_size} (tensor_parallel_size) "
+                    f"required by sequence parallelism; please adjust "
+                    f"num_speculative_tokens or disable sequence parallelism."
+                )
+
+        if not self.cudagraph_capture_sizes or multiple_of <= 1:
+            return
+
+        assert self.max_cudagraph_capture_size is not None
+        rounded_sizes = sorted(
+            set(
+                round_up(size, multiple_of)
+                for size in self.cudagraph_capture_sizes
+                if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
+            )
+        )
+
+        if len(rounded_sizes) == 0:
+            logger.warning(
+                "No valid cudagraph sizes after rounding to a multiple of "
+                "num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
+                " or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
+                multiple_of,
+            )
+            return
+
+        self.max_cudagraph_capture_size = rounded_sizes[-1]
+        self.cudagraph_capture_sizes = rounded_sizes
+
+        # Recompute after adjusting the cudagraph sizes
+        self.compute_bs_to_padded_graph_size()
+
+    def compute_bs_to_padded_graph_size(self):
+        # pre-compute the mapping from batch size to padded graph size
+        self.bs_to_padded_graph_size = [
+            0 for i in range(self.max_cudagraph_capture_size + 1)
+        ]
+        for end, start in zip(
+            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
+            [0] + self.cudagraph_capture_sizes,
+        ):
+            for bs in range(start, end):
+                if bs == start:
+                    self.bs_to_padded_graph_size[bs] = start
+                else:
+                    self.bs_to_padded_graph_size[bs] = end
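
For reference, a minimal standalone sketch of the rounding that `adjust_cudagraph_sizes_for_spec_decode` performs. The capture sizes here are made up for illustration (not vLLM defaults), and the local `round_up` is assumed to match the semantics of `vllm.utils.math_utils.round_up` (round up to the next multiple):

```python
# Hypothetical values for illustration; round_up is assumed to mirror
# vllm.utils.math_utils.round_up.
def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

capture_sizes = [1, 2, 4, 8, 16, 24, 32]  # illustrative only
max_capture_size = 32
multiple_of = 4  # uniform_decode_query_len = num_speculative_tokens + 1, TP <= 1

rounded = sorted(
    {
        round_up(s, multiple_of)
        for s in capture_sizes
        if round_up(s, multiple_of) <= max_capture_size
    }
)
print(rounded)  # [4, 8, 16, 24, 32] -- 1 and 2 both round up to 4
```

Note how 1 and 2 collapse onto 4, which is why the result goes through a set before sorting.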
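
And a sketch of the batch-size-to-padded-graph-size table that `compute_bs_to_padded_graph_size` precomputes, again with made-up sizes; each batch size maps to the smallest captured size that can hold it:

```python
def build_padding_map(capture_sizes: list[int], max_size: int) -> list[int]:
    # Same loop structure as compute_bs_to_padded_graph_size: pad each
    # batch size bs up to the nearest captured size >= bs.
    padded = [0] * (max_size + 1)
    for end, start in zip(capture_sizes + [max_size + 1], [0] + capture_sizes):
        for bs in range(start, end):
            padded[bs] = start if bs == start else end
    return padded

print(build_padding_map([1, 2, 4, 8], 8))
# [0, 1, 2, 4, 4, 8, 8, 8, 8] -> a batch of 3 or 4 runs the size-4 graph,
# batches of 5..8 run the size-8 graph
```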