@@ -18,6 +18,7 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.math_utils import round_up
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
 if TYPE_CHECKING:
@@ -773,19 +774,8 @@ def post_init_cudagraph_sizes(self) -> None: |
         if self.cudagraph_capture_sizes:
             assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
 
-        # pre-compute the mapping from batch size to padded graph size
-        self.bs_to_padded_graph_size = [
-            0 for i in range(self.max_cudagraph_capture_size + 1)
-        ]
-        for end, start in zip(
-            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
-            [0] + self.cudagraph_capture_sizes,
-        ):
-            for bs in range(start, end):
-                if bs == start:
-                    self.bs_to_padded_graph_size[bs] = start
-                else:
-                    self.bs_to_padded_graph_size[bs] = end
+        # May get recomputed in the model runner if adjustment is needed for spec-decode
+        self.compute_bs_to_padded_graph_size()
 
     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called only when mode is
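For reference, the table that `compute_bs_to_padded_graph_size` builds pads every batch size up to the next captured graph size, while exact capture sizes map to themselves. A minimal standalone sketch of the same logic (the helper name and example sizes below are illustrative, not part of this PR):

```python
# Standalone sketch of the padding table; mirrors the extracted method
# outside the config class. Example sizes are hypothetical.
def build_padding_table(capture_sizes: list[int]) -> list[int]:
    max_size = capture_sizes[-1]  # capture sizes are sorted ascending
    table = [0] * (max_size + 1)
    for end, start in zip(capture_sizes + [max_size + 1], [0] + capture_sizes):
        for bs in range(start, end):
            # A batch that exactly matches a captured size runs that graph;
            # anything in between pads up to the next captured size.
            table[bs] = start if bs == start else end
    return table

print(build_padding_table([1, 2, 4, 8]))
# [0, 1, 2, 4, 4, 8, 8, 8, 8] -> e.g. a batch of 3 runs the size-4 graph
```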
@@ -922,3 +912,64 @@ def custom_op_log_check(self): |
                 enable_str,
                 op,
             )
+
+    def adjust_cudagraph_sizes_for_spec_decode(
+        self, uniform_decode_query_len: int, tensor_parallel_size: int
+    ):
+        multiple_of = uniform_decode_query_len
+        if tensor_parallel_size > 1:
+            multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
+            if (
+                multiple_of % uniform_decode_query_len != 0
+                or multiple_of % tensor_parallel_size != 0
+            ):
+                raise ValueError(
+                    f"Can't determine cudagraph shapes that are both a "
+                    f"multiple of {uniform_decode_query_len} "
+                    f"(num_speculative_tokens + 1) required by spec-decode "
+                    f"and {tensor_parallel_size} (tensor_parallel_size) "
+                    f"required by sequence parallelism; please adjust "
+                    f"num_speculative_tokens or disable sequence parallelism"
+                )
+
+        if not self.cudagraph_capture_sizes or multiple_of <= 1:
+            return
+
+        assert self.max_cudagraph_capture_size is not None
+        rounded_sizes = sorted(
+            set(
+                round_up(size, multiple_of)
+                for size in self.cudagraph_capture_sizes
+                if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
+            )
+        )
+
+        if len(rounded_sizes) == 0:
+            logger.warning(
+                "No valid cudagraph sizes after rounding to a multiple of "
+                "num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
+                " or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
+                multiple_of,
+            )
+            return
+
+        self.max_cudagraph_capture_size = rounded_sizes[-1]
+        self.cudagraph_capture_sizes = rounded_sizes
+
+        # Recompute after adjusting the cudagraph sizes
+        self.compute_bs_to_padded_graph_size()
+
+    def compute_bs_to_padded_graph_size(self):
+        # pre-compute the mapping from batch size to padded graph size
+        self.bs_to_padded_graph_size = [
+            0 for i in range(self.max_cudagraph_capture_size + 1)
+        ]
+        for end, start in zip(
+            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
+            [0] + self.cudagraph_capture_sizes,
+        ):
+            for bs in range(start, end):
+                if bs == start:
+                    self.bs_to_padded_graph_size[bs] = start
+                else:
+                    self.bs_to_padded_graph_size[bs] = end
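To illustrate the new adjustment: each capture size is rounded up to a multiple of `uniform_decode_query_len` (and of `tensor_parallel_size` when sequence parallelism is in play), deduplicated, and clamped to `max_cudagraph_capture_size`. A minimal sketch, with `round_up` reimplemented locally and the example values assumed rather than taken from the PR:

```python
# Local stand-in assumed to match vllm.utils.math_utils.round_up:
# the smallest multiple of m that is >= x.
def round_up(x: int, m: int) -> int:
    return ((x + m - 1) // m) * m

capture_sizes = [1, 2, 4, 8, 16]   # hypothetical cudagraph_capture_sizes
max_capture_size = 16
multiple_of = 4                    # e.g. num_speculative_tokens=3 -> query len 4

rounded = sorted(
    {
        round_up(s, multiple_of)
        for s in capture_sizes
        if round_up(s, multiple_of) <= max_capture_size
    }
)
print(rounded)  # [4, 8, 16]: every captured batch is a whole uniform-decode step
```

One design note worth flagging: the compatibility check uses `max(uniform_decode_query_len, tensor_parallel_size)` rather than their least common multiple, so a pair like query len 3 with TP 2 raises even though a multiple of 6 would satisfy both constraints.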