[Kernel] Improve 2D Triton Attention Kernel #28576
Conversation
Code Review
This pull request introduces significant improvements to the 2D Triton attention kernel, focusing on performance and resource utilization. The changes include exact Q-block calculation, improved Q-block to sequence mapping, separate launch paths for prefill and decode, and kernel simplification for decode-only batches.
My review of the changes indicates that:
- The logic for exact Q-block calculation has been correctly moved to the TritonAttentionMetadataBuilder, which is more efficient as it is computed on the CPU once per batch.
- The Triton kernels (kernel_unified_attention_2d, kernel_unified_attention_3d, reduce_segments) have been simplified and refactored to support the new launch strategies, especially for decode-only paths. This makes the code cleaner and more specialized.
- The introduction of seq_threshold_3D allows for dynamic selection between the 2D and 3D kernels for decode, with thoughtful consideration for CUDA graph capture.
- The test suite has been appropriately updated to cover the new functionality and parameters.
Overall, the changes are well-structured, thoroughly implemented, and align with the goal of laying a foundation for future optimizations. The code quality is high, and I did not find any critical or high-severity issues. This is an excellent contribution.
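The review notes that the exact Q-block calculation now happens once per batch on the CPU in the TritonAttentionMetadataBuilder. A minimal sketch of that kind of computation is shown below; the function name, tensor layout, and dtype choices are illustrative assumptions, not code taken from the PR.

```python
import torch

def build_q_block_metadata(query_lens: torch.Tensor, block_q: int):
    """query_lens: 1-D CPU integer tensor of per-sequence query lengths."""
    # Exact number of Q-blocks per sequence: ceil(query_len / BLOCK_Q).
    q_blocks_per_seq = (query_lens + block_q - 1) // block_q
    num_q_blocks = int(q_blocks_per_seq.sum())
    # Prefix sums mark the first Q-block index of each sequence; a kernel
    # program can look up its program id in these boundaries to recover
    # which sequence it should work on.
    boundaries = torch.zeros(query_lens.numel() + 1, dtype=q_blocks_per_seq.dtype)
    boundaries[1:] = torch.cumsum(q_blocks_per_seq, dim=0)
    return num_q_blocks, boundaries

# Example: query lengths [5, 1, 17] with BLOCK_Q = 16 give 1 + 1 + 2 = 4
# Q-blocks and boundaries [0, 1, 2, 4], instead of a padded upper bound.
num_q_blocks, boundaries = build_q_block_metadata(torch.tensor([5, 1, 17]), 16)
```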
💡 Codex Review
Here are some automated review suggestions for this pull request.
q_descale,
k_descale,
v_descale,
BLOCK_M,
BLOCK_Q,
num_q_blocks,
block_q_seq_boundaries_tensor,
seq_threshold_3D,
num_par_softmax_segments,
softmax_segm_output,
softmax_segm_max,
softmax_segm_expsum,
Update unified_attention call sites for new required args
The unified_attention helper now requires additional positional arguments (BLOCK_M, BLOCK_Q, num_q_blocks, block_q_seq_boundaries_tensor, and the softmax scratch buffers). Only the Triton backend was updated to pass them, but other callers such as vllm/v1/attention/backends/tree_attn.py, vllm/v1/attention/backends/xformers.py, and vllm/v1/attention/backends/rocm_aiter_unified_attn.py still invoke this function with the old signature. Any model using those backends will now raise TypeError: unified_attention() missing required positional argument … as soon as attention is evaluated. Either provide defaults or update all call sites; otherwise these backends are unusable after this change.
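One way to read the "provide defaults" option is sketched below: the newly added parameters become optional keyword arguments so legacy call sites keep working. The *legacy_args placeholder and the fallback comment are assumptions for illustration; the PR may instead choose to update every call site.

```python
def unified_attention(
    *legacy_args,                      # original positional arguments, unchanged
    BLOCK_M=None,
    BLOCK_Q=None,
    num_q_blocks=None,
    block_q_seq_boundaries_tensor=None,
    seq_threshold_3D=None,
    num_par_softmax_segments=None,
    softmax_segm_output=None,
    softmax_segm_max=None,
    softmax_segm_expsum=None,
):
    # If a legacy caller (e.g. tree_attn.py or xformers.py) omits the new
    # metadata, it could be recomputed here from the query lengths at some
    # per-call CPU cost; callers that pass it precomputed skip that work.
    ...
```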
CC @tdoublep
Purpose
This pull request depends on PR #28306. It introduces the following improvements to the 2D attention kernel:
- Exact Q-block calculation, performed on the CPU once per batch in the TritonAttentionMetadataBuilder.
- Improved Q-block to sequence mapping (a host-side sketch of this lookup follows the list).
- Separate launch paths for prefill and decode.
- Kernel simplification for decode-only batches.
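The Q-block to sequence mapping can be pictured as a lookup into the cumulative Q-block boundaries produced by the metadata builder. The following is a host-side illustration only; the kernel performs an equivalent lookup on the GPU, and the names and bisect-based search here are assumptions, not the PR's implementation.

```python
import bisect

def q_block_to_seq(q_block_idx: int, block_q_seq_boundaries: list[int]):
    # block_q_seq_boundaries holds prefix sums of Q-blocks per sequence,
    # e.g. [0, 1, 2, 4] for per-sequence Q-block counts [1, 1, 2].
    seq_idx = bisect.bisect_right(block_q_seq_boundaries, q_block_idx) - 1
    # Position of this Q-block within its own sequence.
    q_block_in_seq = q_block_idx - block_q_seq_boundaries[seq_idx]
    return seq_idx, q_block_in_seq

# Q-block 3 belongs to the third sequence (index 2) and is its second block.
assert q_block_to_seq(3, [0, 1, 2, 4]) == (2, 1)
```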
Impact and benefits:
The changes in this PR are not intended to deliver significant performance gains immediately, but rather to lay the groundwork for future optimizations in upcoming PRs.
Note that the original dependencies on other PRs were removed. Additionally, the first two modifications listed above (exact Q-block calculation and Q-block to sequence mapping) have been applied only to the 2D kernel, not the 3D kernel. The latter will be optimized using a different approach in an upcoming PR.
Test Plan
The unit test ./tests/kernels/attention/test_triton_unified_attention.py has been updated to reflect the above changes. Other tests, such as lm_eval, remain compatible and can be used without modification.
Test Result
Unit test results for the updated Triton unified attention kernel (this PR):
lm_eval results for the updated Triton unified attention kernel (this PR) are similar to the lm_eval results obtained with FlashAttention:
@tdoublep @bringlein