
Conversation

@jvlunteren (Contributor) commented Nov 12, 2025

Purpose

This pull request depends on PR #28306.

This pull request introduces the following improvements to the 2D attention kernel:

  • Exact Q-block calculation: replace the previous estimate of the number of Q blocks per sequence, and of the overall Q-block count, with an exact calculation.
  • Q-block to sequence mapping: replace the 2D kernel function that determines the sequence for each Q block with a version based on the new exact calculation (both are illustrated in the sketch after this list).
  • Separate launch paths: add distinct code paths for launching the 2D kernel for mixed prefill/decode batches versus decode-only batches.
  • Kernel simplification for decode: adapt the 2D kernel to remove the execution of instructions required only for prefill when running decode attention.
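
To make the first two items concrete, here is a minimal sketch (not the actual vLLM code; the names block_q and query_start_loc are assumptions modeled on the kernel's BLOCK_Q parameter and the backend's prefix-sum tensor) of how the exact per-sequence Q-block counts and their cumulative boundaries can be computed once on the CPU:

import torch

def build_q_block_metadata(query_start_loc: torch.Tensor, block_q: int):
    """Sketch: exact Q-block counts per sequence and their cumulative boundaries.

    query_start_loc: int32 tensor of shape [num_seqs + 1] holding the prefix
    sums of the per-sequence query lengths (CPU copy, assumed available).
    block_q: number of query tokens covered by one Q block (BLOCK_Q).
    """
    query_lens = query_start_loc[1:] - query_start_loc[:-1]      # per-sequence query length
    q_blocks_per_seq = (query_lens + block_q - 1) // block_q     # exact ceil-division, no estimate
    boundaries = torch.zeros(query_lens.numel() + 1, dtype=torch.int32)
    boundaries[1:] = torch.cumsum(q_blocks_per_seq, dim=0)       # Q-block index boundaries per sequence
    num_q_blocks = int(boundaries[-1])                           # exact total, available as a Python int
    return num_q_blocks, boundaries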

Impact and benefits:

  • The first two changes reduce redundant Q blocks and eliminate unnecessary launches of Triton program instances that perform no actual work, improving GPU resource utilization.
  • Both changes rely on a data structure built by the attention metadata builder class, which amortizes the (small) computation cost across all attention layers. This also removes the CPU–GPU synchronization from the attention kernel launch code, enabling CUDA Graph capture (see the sketch after this list).
  • The last two changes allow separate tuning of the 2D kernel for mixed prefill/decode and decode-only batches, and also simplify CUDA Graph support.
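
For example, a minimal sketch (assuming the boundaries tensor from the sketch above; this is illustrative host-side code, not the Triton kernel itself) of how a global Q-block index is mapped back to its sequence, and why the kernel launch no longer needs a device-to-host copy:

import torch

def q_block_to_seq_idx(q_block_ids: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
    # boundaries[s] .. boundaries[s + 1] - 1 are the Q blocks owned by sequence s,
    # so a right-sided binary search over the boundaries recovers s.
    return torch.searchsorted(boundaries, q_block_ids, right=True) - 1

# Example: three sequences owning 2, 1 and 3 Q blocks, respectively.
boundaries = torch.tensor([0, 2, 3, 6], dtype=torch.int32)
q_block_ids = torch.arange(6, dtype=torch.int32)
print(q_block_to_seq_idx(q_block_ids, boundaries))  # tensor([0, 0, 1, 2, 2, 2])

# Because num_q_blocks is a precomputed Python int, sizing the launch grid does
# not force a GPU synchronization (a prerequisite for CUDA Graph capture):
# grid = (num_q_blocks * num_kv_heads,)   # hypothetical grid shape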

The changes in this PR are not intended to deliver significant performance gains immediately, but rather to lay the groundwork for future optimizations in upcoming PRs.

Note that the original dependencies on other PRs were removed. Additionally, the first two modifications listed above (exact Q-block calculation and Q-block to sequence mapping) have been applied only to the 2D kernel, not the 3D kernel. The latter will be optimized using a different approach in an upcoming PR.

Test Plan

The unit test ./tests/kernels/attention/test_triton_unified_attention.py has been updated to reflect the above changes. Other tests, such as lm_eval, remain compatible and can be used without modification.

Test Result

Unit test results for the updated Triton unified attention kernel (this PR):

python3 -m pytest tests/kernels/attention/test_triton_unified_attention.py

================================================ 512 passed in 155.00s (0:02:34) ================================================


lm_eval results for the updated Triton unified attention kernel (this PR):

VLLM_ATTENTION_BACKEND=TRITON_ATTN lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.798|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.786|±  |0.0184|

These results are comparable to those obtained with FlashAttention:

lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.794|±  |0.0181|
|     |       |strict-match    |     5|exact_match|↑  |0.772|±  |0.0188|

@tdoublep @bringlein

jvlunteren and others added 15 commits November 3, 2025 11:28
@gemini-code-assist bot left a comment


Code Review

This pull request introduces significant improvements to the 2D Triton attention kernel, focusing on performance and resource utilization. The changes include exact Q-block calculation, improved Q-block to sequence mapping, separate launch paths for prefill and decode, and kernel simplification for decode-only batches.

My review of the changes indicates that:

  • The logic for exact Q-block calculation has been correctly moved to the TritonAttentionMetadataBuilder, which is more efficient as it's computed on the CPU once per batch.
  • The Triton kernels (kernel_unified_attention_2d, kernel_unified_attention_3d, reduce_segments) have been simplified and refactored to support the new launch strategies, especially for decode-only paths. This makes the code cleaner and more specialized.
  • The introduction of seq_threshold_3D allows dynamic selection between the 2D and 3D kernels for decode, with thoughtful consideration for CUDA Graph capture (see the sketch after this review comment).
  • The test suite has been appropriately updated to cover the new functionality and parameters.

Overall, the changes are well-structured, thoroughly implemented, and align with the goal of laying a foundation for future optimizations. The code quality is high, and I did not find any critical or high-severity issues. This is an excellent contribution.
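
As a rough illustration of the seq_threshold_3D selection mentioned above (a sketch only: the threshold name comes from the diff, but the helper, the comparison direction, and the rationale in the comments are assumptions, not the PR's actual logic):

def select_decode_kernel(num_seqs: int, seq_threshold_3d: int) -> str:
    # Assumption: a small decode-only batch exposes little parallelism across
    # Q blocks, so the 3D variant (which additionally splits the KV dimension
    # into softmax segments and reduces them afterwards) is preferred below
    # the threshold, while larger batches use the simpler 2D kernel.
    return "3d" if num_seqs <= seq_threshold_3d else "2d"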

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 708 to 719
q_descale,
k_descale,
v_descale,
BLOCK_M,
BLOCK_Q,
num_q_blocks,
block_q_seq_boundaries_tensor,
seq_threshold_3D,
num_par_softmax_segments,
softmax_segm_output,
softmax_segm_max,
softmax_segm_expsum,


P0: Update unified_attention call sites for new required args

The unified_attention helper now requires additional positional arguments (BLOCK_M, BLOCK_Q, num_q_blocks, block_q_seq_boundaries_tensor, and the softmax scratch buffers). Only the Triton backend was updated to pass them, but other callers such as vllm/v1/attention/backends/tree_attn.py, vllm/v1/attention/backends/xformers.py, and vllm/v1/attention/backends/rocm_aiter_unified_attn.py still invoke this function with the old signature. Any model using those backends will now raise TypeError: unified_attention() missing required positional argument … as soon as attention is evaluated. Either provide defaults or update all call sites; otherwise these backends are unusable after this change.
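
One way to realize the "provide defaults" option suggested above could look roughly like this (a hedged sketch: the new parameter names are taken from the excerpt above, but the leading arguments, the chosen defaults, and the fallback behaviour are placeholders, not the PR's actual signature):

def unified_attention(
    q, k, v, out,
    # ... existing arguments unchanged ...
    BLOCK_M=None,                      # None -> fall back to the previous heuristic
    BLOCK_Q=None,
    num_q_blocks=None,
    block_q_seq_boundaries_tensor=None,
    seq_threshold_3D=None,
    num_par_softmax_segments=None,
    softmax_segm_output=None,          # scratch buffers, allocated lazily when not provided
    softmax_segm_max=None,
    softmax_segm_expsum=None,
):
    ...

With defaults like these, callers such as tree_attn.py, xformers.py, and rocm_aiter_unified_attn.py that still use the old positional signature would keep working; the alternative is to update those call sites to pass the new arguments explicitly.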


jvlunteren and others added 3 commits November 13, 2025 09:19
@heheda12345 (Collaborator) commented

CC @tdoublep

jvlunteren and others added 3 commits November 21, 2025 16:20