[Kernel] Support CUDA Graphs in 3D Triton Attention Kernel #28306
@@ -17,7 +17,7 @@
     triton_reshape_and_cache_flash,
 )
 from vllm.attention.ops.triton_unified_attention import unified_attention
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -26,6 +26,7 @@
 )
 from vllm.platforms import current_platform
 from vllm.platforms.interface import DeviceCapability
+from vllm.utils.math_utils import next_power_of_2
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
     AttentionMetadataBuilder,
@@ -36,6 +37,11 @@
 logger = init_logger(__name__)


+# constants
+MIN_LAUNCH_GRID_SIZE_2D = 128  # Minimum launch grid size of 2D kernel
+NUM_PAR_SOFTMAX_SEGMENTS = 16  # Number of parallel tiled softmax segments
+
+
 @dataclass
 class TritonAttentionMetadata:
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
@@ -54,6 +60,12 @@ class TritonAttentionMetadata:
     block_table: torch.Tensor
     slot_mapping: torch.Tensor

+    seq_threshold_3D: int
+    num_par_softmax_segments: int
+    softmax_segm_output: torch.Tensor
+    softmax_segm_max: torch.Tensor
+    softmax_segm_expsum: torch.Tensor
+
     # For cascade attention.
     use_cascade: bool
     common_prefix_len: int
@@ -87,6 +99,72 @@ def __init__(
         self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
         self.headdim = model_config.get_head_size()

+        # Check if CUDA Graphs are enabled for decode
+        self.decode_cudagraph_enabled = (
+            self.vllm_config.compilation_config.cudagraph_mode
+            in (
+                CUDAGraphMode.FULL_AND_PIECEWISE,
+                CUDAGraphMode.FULL_DECODE_ONLY,
+                CUDAGraphMode.FULL,
+            )
+        )
+
+        # The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
+        # A lower bound for num_q_blocks is the number of sequences.
+        # To ensure the minimum launch grid size is achieved, the number of sequences
+        # must be at least equal to the threshold below.
+        # If this threshold is not reached (i.e., the batch size is not large enough),
+        # the 3D kernel will be selected instead.
+        self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
+
+        # Modify the threshold if needed.
+        if self.decode_cudagraph_enabled:
+            capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
+            if not capture_sizes:
+                # If no CUDA Graph capture sizes are specified, the threshold
+                # is reset to zero, forcing the 2D kernel to be used.
+                self.seq_threshold_3D = 0
+            else:
+                # Select the CUDA Graph capture size closest to self.seq_threshold_3D
+                # as threshold. This ensures that each captured graph covers the
+                # correct execution path.
+                upd_seq_threshold_3D = min(
+                    capture_sizes,
+                    key=lambda x: abs(x - self.seq_threshold_3D),
+                )
Comment on lines +128 to +134

Member: What would happen if the threshold isn't one of the capture sizes?

Contributor (Author): Let's consider an example where the threshold is 12 and the closest capture sizes are 8 and 16. In this case, the CUDA Graph associated with capture size 8 will handle the 3D kernel, while the one for capture size 16 will handle the 2D kernel. If, during normal operation, a batch arrives with a size between 9 and 16, it will be processed by replaying the CUDA Graph for the 2D kernel. Notice that the threshold no longer influences this decision: batches in that range, whether below or above the threshold, are all processed using the 2D kernel graph.

Based on this observation, I found it cleaner to align the threshold with an actual capture size so that batches with sizes up to and including the threshold are handled by the 3D kernel. At the same time, not aligning the threshold with a capture size does not affect output correctness. Therefore, the additional complexity of determining the exact threshold based on capture sizes could be eliminated.
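As a small self-contained sketch of the selection behaviour described in this reply, using the same hypothetical numbers (threshold 12, capture sizes 8 and 16); the variable names and values are illustrative, not the PR's exact code:

```python
# Illustrative sketch only (not the PR's exact code): shows how the initial
# threshold snaps to the nearest CUDA Graph capture size and which kernel a
# given decode batch size would end up using.
capture_sizes = [8, 16]   # assumed capture sizes
seq_threshold_3d = 12     # assumed initial threshold (MIN_LAUNCH_GRID_SIZE_2D // num_heads_kv)

# Align the threshold with an actual capture size so that every captured graph
# exercises exactly one kernel path (ties resolve to the first candidate).
aligned_threshold = min(capture_sizes, key=lambda s: abs(s - seq_threshold_3d))

for batch_size in (4, 8, 12, 16):
    kernel = "3D" if batch_size <= aligned_threshold else "2D"
    print(f"batch_size={batch_size:2d} -> {kernel} kernel (aligned threshold={aligned_threshold})")

# With these values the threshold snaps to 8: batches of size 1-8 replay the
# 3D-kernel graph, batches of size 9-16 replay the 2D-kernel graph.
```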
+
+                # If the updated threshold becomes significantly larger than the
+                # initial value, it is reset to zero. This enforces the use of the
+                # 2D kernel only and ensures that the size of the allocated
+                # intermediate structures remains bounded.
+                if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
+                    self.seq_threshold_3D = upd_seq_threshold_3D
+                else:
+                    self.seq_threshold_3D = 0
Comment on lines +136 to +143

Member: How big can the intermediate data structures really get?

Contributor (Author): The data structures have the following dimensions (as allocated above): softmax_segm_output is (seq_threshold_3D, num_heads_q, num_par_softmax_segments, headdim_padded), and softmax_segm_max / softmax_segm_expsum are each (seq_threshold_3D, num_heads_q, num_par_softmax_segments), all stored as float32.
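As a rough, assumption-based illustration of the resulting sizes (all concrete numbers below are placeholders, not figures from this PR):

```python
# Rough footprint estimate for the intermediate softmax buffers.
# All values below are assumed for illustration (not taken from the PR):
# 32 query heads, head size 128 (already a power of two), seq_threshold_3D = 16,
# NUM_PAR_SOFTMAX_SEGMENTS = 16, float32 storage.
seq_threshold_3d = 16
num_heads_q = 32
num_segments = 16
headdim_padded = 128
bytes_per_elem = 4  # float32

segm_output = seq_threshold_3d * num_heads_q * num_segments * headdim_padded
segm_max = seq_threshold_3d * num_heads_q * num_segments
segm_expsum = segm_max  # same shape as softmax_segm_max

total_bytes = (segm_output + segm_max + segm_expsum) * bytes_per_elem
print(f"total ~ {total_bytes / 2**20:.2f} MiB")  # about 4.06 MiB for these values
```

Because every dimension is bounded (seq_threshold_3D is capped at roughly 4x its initial value, and the head and segment counts are fixed), the allocation stays small relative to typical KV-cache memory.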
+
+        self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
+        headdim_padded = next_power_of_2(self.headdim)
+        self.softmax_segm_output = torch.empty(
+            (
+                self.seq_threshold_3D,
+                self.num_heads_q,
+                self.num_par_softmax_segments,
+                headdim_padded,
+            ),
+            dtype=torch.float32,
+            device=device,
+        )
+        self.softmax_segm_max = torch.empty(
+            (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
+            dtype=torch.float32,
+            device=device,
+        )
+        self.softmax_segm_expsum = torch.empty(
+            (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
+            dtype=torch.float32,
+            device=device,
+        )
+
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> TritonAttentionMetadata:
@@ -143,6 +221,11 @@ def build(
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
+            seq_threshold_3D=self.seq_threshold_3D,
+            num_par_softmax_segments=self.num_par_softmax_segments,
+            softmax_segm_output=self.softmax_segm_output,
+            softmax_segm_max=self.softmax_segm_max,
+            softmax_segm_expsum=self.softmax_segm_expsum,
         )
         return attn_metadata

@@ -346,6 +429,12 @@ def forward(
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table

+        seq_threshold_3D = attn_metadata.seq_threshold_3D
+        num_par_softmax_segments = attn_metadata.num_par_softmax_segments
+        softmax_segm_output = attn_metadata.softmax_segm_output
+        softmax_segm_max = attn_metadata.softmax_segm_max
+        softmax_segm_expsum = attn_metadata.softmax_segm_expsum
+
         descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

         unified_attention(
@@ -366,6 +455,11 @@ def forward(
             q_descale=None,  # Not supported
             k_descale=layer._k_scale.expand(descale_shape),
             v_descale=layer._v_scale.expand(descale_shape),
+            seq_threshold_3D=seq_threshold_3D,
+            num_par_softmax_segments=num_par_softmax_segments,
+            softmax_segm_output=softmax_segm_output,
+            softmax_segm_max=softmax_segm_max,
+            softmax_segm_expsum=softmax_segm_expsum,
             sinks=self.sinks,
             output_scale=output_scale,
         )
Reviewer: What does it mean to use CUDA graphs with no capture sizes? Is this a case that can actually happen?

Contributor (Author): I don't think this can happen (see above: #28306 (comment)). I included those lines just to be sure.
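For context, a hedged configuration sketch of how the CUDA Graph mode and capture sizes might be set explicitly, which guarantees cudagraph_capture_sizes is non-empty. The fields mirror what the diff reads from compilation_config, but exact constructor names and accepted values may differ across vLLM versions, and the model name is a placeholder:

```python
from vllm import LLM
from vllm.config import CompilationConfig, CUDAGraphMode

# Assumed example: request full CUDA Graphs for decode and spell out the
# capture sizes, so the 3D/2D threshold can be aligned with a captured size.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    compilation_config=CompilationConfig(
        cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
        cudagraph_capture_sizes=[1, 2, 4, 8, 16, 32],
    ),
)
```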