
Conversation

@jvlunteren (Contributor) commented Nov 7, 2025

Purpose

This pull request depends on PR #27993.

This pull request adapts the 3D Triton attention kernel, which is used exclusively for decode operations, to support full CUDA Graphs. The key changes include:

  • The allocation of the intermediate data structures used for the tiled softmax implementation has been moved to the attention metadata builder class.

  • The dynamic selection between the 2D and 3D attention kernels during decode is now based on comparing the batch size against a threshold that corresponds to one of the CUDA Graph capture sizes. This ensures that, for each batch size, exactly one valid kernel choice (either the 2D or the 3D kernel) exists and will be captured correctly.

The updated kernel selection logic is only applied when CUDA Graphs are enabled for decoding. This is now automatically detected in the attention metadata builder, which sets the appropriate threshold values accordingly.
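To make the resulting behavior concrete, here is a minimal sketch of the decode-time choice; choose_decode_kernel is a hypothetical helper written for this description, not code from the PR:

    # Minimal sketch of the decode-time kernel choice described above.
    # seq_threshold_3D is assumed to already be aligned with a CUDA Graph
    # capture size by the attention metadata builder.
    def choose_decode_kernel(num_seqs: int, seq_threshold_3D: int) -> str:
        # Small decode batches use the 3D (split-softmax) kernel; larger
        # batches use the 2D kernel. Because the threshold equals a capture
        # size, every captured batch size maps to exactly one kernel.
        return "3D" if num_seqs <= seq_threshold_3D else "2D"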

Test Plan

The unit test ./tests/kernels/attention/test_triton_unified_attention.py has been updated to include the allocation of intermediate data structures required for the tiled softmax implementation. It has also been modified to explicitly test the separate use of the 2D and 3D attention kernels during decoding. Other tests, such as lm_eval, remain compatible and can be used without modification.

Test Result

Unit test results for the updated Triton unified attention kernel (this PR):

python3 -m pytest tests/kernels/attention/test_triton_unified_attention.py

================================================ 512 passed in 71.05s (0:01:11) ================================================


lm_eval results for the updated Triton unified attention kernel (this PR):

VLLM_ATTENTION_BACKEND=TRITON_ATTN lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.800|±  |0.0179|
|     |       |strict-match    |     5|exact_match|↑  |0.788|±  |0.0183|

These are similar to the lm_eval results obtained with FlashAttention:

lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.790|±  |0.0182|
|     |       |strict-match    |     5|exact_match|↑  |0.768|±  |0.0189|

@tdoublep @bringlein

jvlunteren and others added 7 commits November 3, 2025 11:28
@gemini-code-assist bot left a comment

Code Review

This pull request adapts the 3D Triton attention kernel for CUDA graph compatibility, which is a valuable improvement for performance. The core changes involve moving intermediate tensor allocations out of the kernel and aligning the 2D/3D kernel selection threshold with CUDA graph capture sizes. The refactoring of the Triton kernels for decode-only operations is clean and logical.

My review identifies one critical issue where an empty cudagraph_capture_sizes list would cause a ValueError and crash the server. I've provided a suggestion to handle this edge case gracefully. Otherwise, the changes are well-implemented and align with the stated goals.

Comment on lines 113 to 129
if self.decode_cudagraph_enabled:
    # Select the CUDA Graph capture size closest to self.seq_threshold_3D
    # as threshold. This ensures that each captured graph covers the
    # correct execution path.
    upd_seq_threshold_3D = min(
        self.vllm_config.compilation_config.cudagraph_capture_sizes,
        key=lambda x: abs(x - self.seq_threshold_3D),
    )

    # If the updated threshold becomes significantly larger than the
    # initial value, it is reset to zero. This enforces the use of the
    # 2D kernel only and ensures that the size of the allocated
    # intermediate structures remains bounded.
    if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
        self.seq_threshold_3D = upd_seq_threshold_3D
    else:
        self.seq_threshold_3D = 0
@gemini-code-assist bot commented (critical):

The code does not handle the case where self.vllm_config.compilation_config.cudagraph_capture_sizes is an empty list. If it is empty, min() will raise a ValueError, causing a server crash on startup. This can happen if a user configures CUDA graphs but provides an empty list for cudagraph_capture_sizes.

We should add a check to ensure cudagraph_capture_sizes is not empty before calling min(). If it is empty, we should fall back to a safe default, such as setting self.seq_threshold_3D = 0 to always use the 2D kernel, which is a safe choice for CUDA graph compatibility.

        if self.decode_cudagraph_enabled:
            capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
            if not capture_sizes:
                # If no CUDA graph capture sizes are specified, we cannot
                # guarantee a static kernel choice. Forcing the 2D kernel
                # is the safest option.
                self.seq_threshold_3D = 0
            else:
                # Select the CUDA Graph capture size closest to self.seq_threshold_3D
                # as threshold. This ensures that each captured graph covers the
                # correct execution path.
                upd_seq_threshold_3D = min(
                    capture_sizes,
                    key=lambda x: abs(x - self.seq_threshold_3D),
                )

                # If the updated threshold becomes significantly larger than the
                # initial value, it is reset to zero. This enforces the use of the
                # 2D kernel only and ensures that the size of the allocated
                # intermediate structures remains bounded.
                if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
                    self.seq_threshold_3D = upd_seq_threshold_3D
                else:
                    self.seq_threshold_3D = 0

@jvlunteren (Author) replied:

According to https://docs.vllm.ai/en/latest/configuration/engine_args/#compilationconfig:

If cudagraph_capture_sizes is specified, this will be set to the largest size in that list (or checked for consistency if specified). If cudagraph_capture_sizes is not specified, the list of sizes is generated automatically following the pattern:

[1, 2, 4] + list(range(8, 256, 8)) + list(
    range(256, max_cudagraph_capture_size + 1, 16))

From this, I assumed that cudagraph_capture_sizes is not empty when CUDA Graphs are enabled. However, to be sure, I have implemented the modification as suggested above.
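For illustration, evaluating the documented pattern with an assumed max_cudagraph_capture_size of 512 gives:

    # Evaluating the documented default pattern (max size of 512 assumed):
    sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
        range(256, 512 + 1, 16))
    print(sizes[:6])   # [1, 2, 4, 8, 16, 24]
    print(sizes[-3:])  # [480, 496, 512]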

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 131 to 138
self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
self.softmax_segm_output = torch.empty(
    (
        self.seq_threshold_3D,
        self.num_heads_q,
        self.num_par_softmax_segments,
        self.headdim,
    ),


P1: Preallocate softmax buffers with padded head dimension

The 3D decode kernel writes per-head outputs with a width of triton.next_power_of_2(head_size) (see HEAD_SIZE_PADDED usages in kernel_unified_attention_3d and reduce_segments). However, the metadata builder now preallocates softmax_segm_output with the unpadded self.headdim (lines 131‑137). For models whose head size is not already a power of two (e.g., 80 or 96), the kernel will index past the end of these buffers, corrupting memory or producing incorrect results whenever the 3D path is selected. The buffers should match the padded size used by the kernels.


@jvlunteren (Author) replied:

Resolved this issue by modifying the code and the unit test to allocate the softmax buffers based on the padded head dimension.
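A minimal sketch of that fix, with example dimensions assumed for illustration (attribute names follow the snippets quoted in this thread):

    import torch
    import triton

    # Example values, assumed for illustration only.
    seq_threshold_3D, num_heads_q, num_par_softmax_segments = 16, 32, 16
    headdim = 80  # a head size that is not a power of two

    # Pad the last dimension to the width the kernels index with
    # (HEAD_SIZE_PADDED), so the 3D path never writes out of bounds.
    headdim_padded = triton.next_power_of_2(headdim)  # 80 -> 128
    softmax_segm_output = torch.empty(
        (seq_threshold_3D, num_heads_q, num_par_softmax_segments, headdim_padded),
        dtype=torch.float32,
        device="cuda",
    )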

@jvlunteren changed the title from "Jvl triton attn upd2" to "[Kernel] Support CUDA Graphs in 3D Triton Attention Kernel" on Nov 7, 2025
jvlunteren and others added 3 commits November 10, 2025 08:45
@heheda12345 (Collaborator) commented:

CC @LucasWilkinson @tlrmchlsmth

@mergify bot commented Nov 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jvlunteren.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Nov 11, 2025
@mergify bot removed the needs-rebase label on Nov 11, 2025
Comment on lines +117 to +118
# If no CUDA Graph capture sizes are specified, the threshold
# is reset to zero, forcing the 2D kernel to be used.
Reviewer (Member):

What does it mean to use CUDA graphs with no capture sizes? Is this a case that can actually happen?

@jvlunteren (Author) replied:

I don't think this can happen (see above: #28306 (comment)).
I included those lines just to be sure.

self.seq_threshold_3D = 0

self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
headdim_padded = 1 << (self.headdim - 1).bit_length() # next power of 2 value
Reviewer (Member):

iirc there is a function like triton.next_power_of_2 that might be a bit easier to understand

@jvlunteren (Author) replied:

I am not aware of such a function. For clarity, I will add a custom function to achieve this, based on the above code.

@jvlunteren (Author) followed up:

I noticed that this function is defined in vllm.utils.math_utils. I have modified the code to use that function.
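For reference, a quick check that the bit trick from the diff matches a next_power_of_2 helper:

    # The bit trick from the diff and a next_power_of_2 helper agree:
    def next_power_of_2(n: int) -> int:
        return 1 << (n - 1).bit_length()

    assert next_power_of_2(80) == 128
    assert next_power_of_2(96) == 128
    assert next_power_of_2(128) == 128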

Comment on lines +129 to +136
# If the updated threshold becomes significantly larger than the
# initial value, it is reset to zero. This enforces the use of the
# 2D kernel only and ensures that the size of the allocated
# intermediate structures remains bounded.
if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
    self.seq_threshold_3D = upd_seq_threshold_3D
else:
    self.seq_threshold_3D = 0
Reviewer (Member):

How big can the intermediate data structures really get?

@jvlunteren (Author) replied:

The data structures have the following dimensions:

  • softmax_segm_output: [seq_threshold_3D, num_heads_q, num_par_softmax_segments, headdim_padded]
  • softmax_segm_max: [seq_threshold_3D, num_heads_q, num_par_softmax_segments]
  • softmax_segm_expsum: [seq_threshold_3D, num_heads_q, num_par_softmax_segments]

Example: seq_threshold_3D = 16, num_heads_q = 32, num_par_softmax_segments = 16, headdim_padded = 128

Resulting sizes:

  • softmax_segm_output: 4 MB
  • softmax_segm_max: 32 KB
  • softmax_segm_expsum: 32 KB
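Reproducing that arithmetic, assuming 4-byte (float32) elements, which matches the quoted totals:

    # Buffer sizes for the example dimensions above, in bytes (float32).
    seq, heads, segs, hd = 16, 32, 16, 128
    print(seq * heads * segs * hd * 4)  # 4194304 -> 4 MB  (softmax_segm_output)
    print(seq * heads * segs * 4)       # 32768   -> 32 KB (softmax_segm_max / _expsum)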

Comment on lines +121 to +127
# Select the CUDA Graph capture size closest to self.seq_threshold_3D
# as threshold. This ensures that each captured graph covers the
# correct execution path.
upd_seq_threshold_3D = min(
    capture_sizes,
    key=lambda x: abs(x - self.seq_threshold_3D),
)
Reviewer (Member):

What would happen if the threshold isn't one of the capture sizes?

@jvlunteren (Author) replied:

Let’s consider an example where the threshold is 12 and the closest capture sizes are 8 and 16. In this case, the CUDA Graph associated with capture size 8 will handle the 3D kernel, while the one for capture size 16 will handle the 2D kernel. If, during normal operation, a batch arrives with a size between 9 and 16, it will be processed by replaying the CUDA Graph for the 2D kernel. Notice that the threshold no longer influences this decision: batches in that range, whether below or above the threshold, are all processed using the 2D kernel graph. Based on this observation, I found it cleaner to align the threshold with an actual capture size so that batches with sizes up to and including the threshold are handled by the 3D kernel.

At the same time, not aligning the threshold with a capture size does not affect output correctness. Therefore, the additional complexity of determining the exact threshold based on capture sizes could be eliminated.
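A small sketch of that example (it assumes batches are padded up to the next capture size before graph replay, as described above):

    # Threshold 12 with capture sizes [8, 16]: min() picks the closest
    # capture size (ties resolve to the first element, so 8 here).
    capture_sizes = [8, 16]
    seq_threshold_3D = min(capture_sizes, key=lambda x: abs(x - 12))  # -> 8

    # A decode batch is padded up to the next capture size before replay,
    # so every captured graph maps to exactly one kernel.
    for batch in (6, 9, 12, 16):
        padded = next(s for s in capture_sizes if s >= batch)
        kernel = "3D" if padded <= seq_threshold_3D else "2D"
        print(batch, padded, kernel)  # 6->8: 3D; 9/12/16->16: 2D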

Comment on lines 111 to 113
# Set initial value for the threshold for the number of sequences used
# to select between the 2D and 3D kernels for decode.
self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
Reviewer (Member):

It might be good to explain why we use this formula to decide the target threshold.

@jvlunteren (Author) replied:

I will do that.
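One plausible reading, inferred from the formula itself rather than stated in the PR: the 2D decode kernel launches on the order of num_seqs * num_heads_kv programs, so the 3D kernel is preferred whenever that grid would fall below MIN_LAUNCH_GRID_SIZE_2D. A hedged sketch:

    # Hedged sketch of the rationale (inferred from the formula, not stated
    # in the PR): prefer the 3D kernel when the 2D launch grid
    # (~ num_seqs * num_heads_kv programs) would be too small to fill the GPU.
    def prefer_3d(num_seqs: int, num_heads_kv: int, min_grid_2d: int) -> bool:
        seq_threshold_3D = min_grid_2d // num_heads_kv
        return num_seqs <= seq_threshold_3D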
