Commits (28):
8126fa1  remove prefill support from 3d kernel (jvlunteren, Nov 3, 2025)
4a54c08  formatting (jvlunteren, Nov 3, 2025)
f1f58cc  Merge branch 'main' into jvl-triton-attn-upd1 (jvlunteren, Nov 3, 2025)
84c5cd7  adapt 3D kernel for full CUDA Graph support (jvlunteren, Nov 7, 2025)
39f52b4  formatting (jvlunteren, Nov 7, 2025)
3102959  update unit test (jvlunteren, Nov 7, 2025)
9bcc1fb  corrected comment (jvlunteren, Nov 7, 2025)
f3fdb32  Merge branch 'main' into jvl-triton-attn-upd2 (jvlunteren, Nov 10, 2025)
53d7b8b  added check for empty cudagraph_capture_sizes (jvlunteren, Nov 10, 2025)
a70bf68  allocate softmax buffers with padded head dimension (jvlunteren, Nov 10, 2025)
96576d8  Merge branch 'main' into jvl-triton-attn-upd2 (jvlunteren, Nov 11, 2025)
a62aa11  fix failing ruff check (jvlunteren, Nov 11, 2025)
5d4921f  Merge branch 'main' into jvl-triton-attn-upd2 (jvlunteren, Nov 13, 2025)
5a4173f  Merge branch 'vllm-project:main' into jvl-triton-attn-upd2 (jvlunteren, Nov 13, 2025)
90e746a  remove dependencies on other PRs (jvlunteren, Nov 13, 2025)
5f67875  use math utility for computing next power of 2 (jvlunteren, Nov 17, 2025)
721b319  add comment to explain threshold computation (jvlunteren, Nov 17, 2025)
acf43b8  use next_power_of_2 from vllm.utils.math_utils (jvlunteren, Nov 21, 2025)
e1b0a81  add assert to ensure capture sizes are set for CUDA Graphs (jvlunteren, Nov 21, 2025)
c214d7e  Merge branch 'main' into jvl-triton-attn-upd2 (jvlunteren, Nov 21, 2025)
b0d42fa  remove superfluous check (jvlunteren, Nov 24, 2025)
c9a9aee  Update vllm/v1/attention/backends/triton_attn.py (jvlunteren, Nov 25, 2025)
92ea4a4  made new unified_attention() arguments optional to preserve backward … (jvlunteren, Nov 25, 2025)
c5e317c  updated comment (jvlunteren, Nov 25, 2025)
924d36e  make additonal new argument optional (jvlunteren, Nov 25, 2025)
633a319  Merge branch 'main' into jvl-triton-attn-upd2 (jvlunteren, Nov 27, 2025)
2112153  bugfix and modification (jvlunteren, Nov 27, 2025)
741cf4e  Merge branch 'main' into jvl-triton-attn-upd2 (tdoublep, Dec 3, 2025)
26 changes: 26 additions & 0 deletions tests/kernels/attention/test_triton_unified_attention.py
@@ -22,6 +22,10 @@
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]

# 0: use 2D kernel for decode
# 8: use 3D kernel for decode
SEQ_THRESHOLD_3D_VALUES = [0, 8]


def ref_paged_attn(
query: torch.Tensor,
@@ -92,6 +96,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
@torch.inference_mode()
def test_triton_unified_attn(
seq_lens: list[tuple[int, int]],
Expand All @@ -103,6 +108,7 @@ def test_triton_unified_attn(
soft_cap: float | None,
num_blocks: int,
q_dtype: torch.dtype | None,
seq_threshold_3D: int,
) -> None:
torch.set_default_device("cuda")

@@ -152,6 +158,21 @@
k_descale = torch.rand(scale_shape, dtype=torch.float32)
v_descale = torch.rand(scale_shape, dtype=torch.float32)

num_par_softmax_segments = 16
head_size_padded = 1 << (head_size - 1).bit_length() # next power of 2 value
softmax_segm_output = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
dtype=torch.float32,
)
softmax_segm_max = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
dtype=torch.float32,
)
softmax_segm_expsum = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
dtype=torch.float32,
)

unified_attention(
q=maybe_quantized_query,
k=maybe_quantized_key_cache,
Expand All @@ -169,6 +190,11 @@ def test_triton_unified_attn(
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
softmax_segm_expsum=softmax_segm_expsum,
)

ref_output = ref_paged_attn(
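A side note on the `head_size_padded` computation in the test above: it is the standard next-power-of-two bit trick, and for head sizes >= 1 it agrees with `triton.next_power_of_2`, which the kernel side uses for `HEAD_SIZE_PADDED`. A minimal illustrative sketch (not part of the diff, assumes Triton is importable):

import triton


def next_pow2(n: int) -> int:
    # Smallest power of two >= n, for n >= 1 (same trick as in the test above).
    return 1 << (n - 1).bit_length()


# Spot-check against triton.next_power_of_2 for a few typical head sizes.
for head_size in (32, 40, 64, 80, 96, 128, 192, 256):
    assert next_pow2(head_size) == triton.next_power_of_2(head_size)
# e.g. 80 -> 128, 96 -> 128, 128 -> 128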
58 changes: 18 additions & 40 deletions vllm/attention/ops/triton_unified_attention.py
@@ -355,7 +355,7 @@ def kernel_unified_attention_2d(
@triton.jit
def kernel_unified_attention_3d(
segm_output_ptr,
# [num_tokens, num_query_heads, num_segments, head_size]
# [num_tokens, num_query_heads, num_segments, head_size_padded]
segm_max_ptr, # [num_tokens, num_query_heads, num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments]
query_ptr, # [num_tokens, num_query_heads, head_size]
@@ -749,6 +749,11 @@ def unified_attention(
q_descale,
k_descale,
v_descale,
seq_threshold_3D,
num_par_softmax_segments,
softmax_segm_output,
softmax_segm_max,
softmax_segm_expsum,
alibi_slopes=None,
output_scale=None,
qq_bias=None,
@@ -793,8 +798,8 @@
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

# if batch contains a prefill
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
# if batch contains a prefill or number of sequences is larger than threshold
if max_seqlen_q > 1 or num_seqs > seq_threshold_3D:
kernel_unified_attention_2d[
(
total_num_q_blocks,
@@ -847,37 +852,10 @@
USE_FP8=output_scale is not None,
)
else:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
# value that showed good performance in tests
NUM_SEGMENTS = 16

segm_output = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
triton.next_power_of_2(head_size),
dtype=torch.float32,
device=q.device,
)
segm_max = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
segm_expsum = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)

kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
kernel_unified_attention_3d[(num_seqs, num_kv_heads, num_par_softmax_segments)](
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
@@ -917,13 +895,13 @@ def unified_attention(
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
)
reduce_segments[(q.shape[0], num_query_heads)](
reduce_segments[(num_seqs, num_query_heads)](
output_ptr=out,
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
seq_lens_ptr=seqused_k,
num_seqs=num_seqs,
num_query_heads=num_query_heads,
Expand All @@ -936,6 +914,6 @@ def unified_attention(
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
USE_FP8=output_scale is not None,
)
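For readers skimming the diff: the change above replaces the old `total_num_q_blocks * num_kv_heads > 128` heuristic with a caller-supplied `seq_threshold_3D`, and the intermediate softmax buffers are now passed in rather than allocated on every call. A rough sketch of the resulting kernel selection (illustrative pseudologic with the names from the diff, not the actual launch code):

def choose_decode_kernel(max_seqlen_q: int, num_seqs: int, seq_threshold_3D: int) -> str:
    # Prefill batches (query length > 1) and decode batches larger than the
    # threshold keep the 2D kernel; smaller decode batches take the 3D
    # (split-softmax) kernel followed by reduce_segments.
    if max_seqlen_q > 1 or num_seqs > seq_threshold_3D:
        return "kernel_unified_attention_2d"
    return "kernel_unified_attention_3d + reduce_segments"


result = choose_decode_kernel(max_seqlen_q=1, num_seqs=4, seq_threshold_3D=8)
assert result.startswith("kernel_unified_attention_3d")
assert choose_decode_kernel(max_seqlen_q=8, num_seqs=4, seq_threshold_3D=8) == "kernel_unified_attention_2d"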
89 changes: 88 additions & 1 deletion vllm/v1/attention/backends/triton_attn.py
@@ -17,7 +17,7 @@
triton_reshape_and_cache_flash,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -36,6 +36,11 @@
logger = init_logger(__name__)


# constants
MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel
NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments


@dataclass
class TritonAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
Expand All @@ -54,6 +59,12 @@ class TritonAttentionMetadata:
block_table: torch.Tensor
slot_mapping: torch.Tensor

seq_threshold_3D: int
num_par_softmax_segments: int
softmax_segm_output: torch.Tensor
softmax_segm_max: torch.Tensor
softmax_segm_expsum: torch.Tensor

# For cascade attention.
use_cascade: bool
common_prefix_len: int
@@ -87,6 +98,66 @@ def __init__(
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
self.headdim = model_config.get_head_size()

# Check if CUDA Graphs are enabled for decode
self.decode_cudagraph_enabled = (
self.vllm_config.compilation_config.cudagraph_mode
in (
CUDAGraphMode.FULL_AND_PIECEWISE,
CUDAGraphMode.FULL_DECODE_ONLY,
CUDAGraphMode.FULL,
)
)

# Set initial value for the threshold for the number of sequences used
# to select between the 2D and 3D kernels for decode.
self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
if self.decode_cudagraph_enabled:
capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
if not capture_sizes:
# If no CUDA Graph capture sizes are specified, the threshold
# is reset to zero, forcing the 2D kernel to be used.
self.seq_threshold_3D = 0
else:
# Select the CUDA Graph capture size closest to self.seq_threshold_3D
# as threshold. This ensures that each captured graph covers the
# correct execution path.
upd_seq_threshold_3D = min(
capture_sizes,
key=lambda x: abs(x - self.seq_threshold_3D),
)

# If the updated threshold becomes significantly larger than the
# initial value, it is reset to zero. This enforces the use of the
# 2D kernel only and ensures that the size of the allocated
# intermediate structures remains bounded.
if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
self.seq_threshold_3D = upd_seq_threshold_3D
else:
self.seq_threshold_3D = 0

self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
headdim_padded = 1 << (self.headdim - 1).bit_length() # next power of 2 value
self.softmax_segm_output = torch.empty(
(
self.seq_threshold_3D,
self.num_heads_q,
self.num_par_softmax_segments,
headdim_padded,
),
dtype=torch.float32,
device=device,
)
self.softmax_segm_max = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)
self.softmax_segm_expsum = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
@@ -143,6 +214,11 @@ def build(
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
seq_threshold_3D=self.seq_threshold_3D,
num_par_softmax_segments=self.num_par_softmax_segments,
softmax_segm_output=self.softmax_segm_output,
softmax_segm_max=self.softmax_segm_max,
softmax_segm_expsum=self.softmax_segm_expsum,
)
return attn_metadata

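The `seq_threshold_3D` selection in `__init__` above can be summarized in isolation as follows. This is a standalone sketch of the same logic; the function name and the example capture sizes are made up for illustration:

def pick_seq_threshold_3d(
    num_heads_kv: int,
    decode_cudagraph_enabled: bool,
    capture_sizes: list[int] | None,
    min_launch_grid_size_2d: int = 128,  # MIN_LAUNCH_GRID_SIZE_2D
) -> int:
    # Initial value: roughly the decode batch size at which the 2D kernel's
    # launch grid (num_seqs * num_heads_kv) reaches the minimum grid size.
    threshold = min_launch_grid_size_2d // num_heads_kv
    if not decode_cudagraph_enabled:
        return threshold
    if not capture_sizes:
        # No CUDA Graph capture sizes specified: force the 2D kernel.
        return 0
    # Snap to the closest capture size so each captured graph stays on one path.
    snapped = min(capture_sizes, key=lambda x: abs(x - threshold))
    # Keep the pre-allocated softmax buffers bounded: if snapping would grow
    # the threshold by more than 4x, fall back to the 2D kernel only.
    return snapped if snapped <= 4 * threshold else 0


# Example: 8 KV heads with example capture sizes -> threshold 16.
print(pick_seq_threshold_3d(8, True, [1, 2, 4, 8, 16, 24, 32]))  # 16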
@@ -346,6 +422,12 @@ def forward(
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table

seq_threshold_3D = attn_metadata.seq_threshold_3D
num_par_softmax_segments = attn_metadata.num_par_softmax_segments
softmax_segm_output = attn_metadata.softmax_segm_output
softmax_segm_max = attn_metadata.softmax_segm_max
softmax_segm_expsum = attn_metadata.softmax_segm_expsum

descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

unified_attention(
Expand All @@ -366,6 +448,11 @@ def forward(
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
softmax_segm_expsum=softmax_segm_expsum,
sinks=self.sinks,
output_scale=output_scale,
)
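Since the three softmax segment buffers are now allocated once in the metadata builder (sized by `seq_threshold_3D`) instead of per `unified_attention` call, their footprint is easy to bound. A rough sizing sketch that mirrors the shapes allocated above, with made-up example numbers:

import torch

seq_threshold_3D = 16       # example threshold (capped as described above)
num_query_heads = 32        # example model
num_segments = 16           # NUM_PAR_SOFTMAX_SEGMENTS
head_size_padded = 128      # next power of two of head_size

segm_output = torch.empty(
    (seq_threshold_3D, num_query_heads, num_segments, head_size_padded),
    dtype=torch.float32,
)
segm_max = torch.empty(
    (seq_threshold_3D, num_query_heads, num_segments), dtype=torch.float32
)
segm_expsum = torch.empty(
    (seq_threshold_3D, num_query_heads, num_segments), dtype=torch.float32
)

# float32 -> 4 bytes per element
total_bytes = 4 * (segm_output.numel() + segm_max.numel() + segm_expsum.numel())
print(f"{total_bytes / 2**20:.2f} MiB")  # ~4.06 MiB for this example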