Commits (25)
8126fa1
remove prefill support from 3d kernel
jvlunteren Nov 3, 2025
4a54c08
formatting
jvlunteren Nov 3, 2025
f1f58cc
Merge branch 'main' into jvl-triton-attn-upd1
jvlunteren Nov 3, 2025
84c5cd7
adapt 3D kernel for full CUDA Graph support
jvlunteren Nov 7, 2025
39f52b4
formatting
jvlunteren Nov 7, 2025
3102959
update unit test
jvlunteren Nov 7, 2025
9bcc1fb
corrected comment
jvlunteren Nov 7, 2025
f3fdb32
Merge branch 'main' into jvl-triton-attn-upd2
jvlunteren Nov 10, 2025
53d7b8b
added check for empty cudagraph_capture_sizes
jvlunteren Nov 10, 2025
a70bf68
allocate softmax buffers with padded head dimension
jvlunteren Nov 10, 2025
96576d8
Merge branch 'main' into jvl-triton-attn-upd2
jvlunteren Nov 11, 2025
a62aa11
fix failing ruff check
jvlunteren Nov 11, 2025
5d4921f
Merge branch 'main' into jvl-triton-attn-upd2
jvlunteren Nov 13, 2025
5a4173f
Merge branch 'vllm-project:main' into jvl-triton-attn-upd2
jvlunteren Nov 13, 2025
90e746a
remove dependencies on other PRs
jvlunteren Nov 13, 2025
5f67875
use math utility for computing next power of 2
jvlunteren Nov 17, 2025
721b319
add comment to explain threshold computation
jvlunteren Nov 17, 2025
acf43b8
use next_power_of_2 from vllm.utils.math_utils
jvlunteren Nov 21, 2025
e1b0a81
add assert to ensure capture sizes are set for CUDA Graphs
jvlunteren Nov 21, 2025
c214d7e
Merge branch 'main' into jvl-triton-attn-upd2
jvlunteren Nov 21, 2025
b0d42fa
remove superfluous check
jvlunteren Nov 24, 2025
c9a9aee
Update vllm/v1/attention/backends/triton_attn.py
jvlunteren Nov 25, 2025
92ea4a4
made new unified_attention() arguments optional to preserve backward …
jvlunteren Nov 25, 2025
c5e317c
updated comment
jvlunteren Nov 25, 2025
924d36e
make additional new argument optional
jvlunteren Nov 25, 2025
25 changes: 25 additions & 0 deletions tests/kernels/attention/test_triton_unified_attention.py
@@ -22,6 +22,10 @@
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]

# 0: use 2D kernel for decode
# 8: use 3D kernel for decode
SEQ_THRESHOLD_3D_VALUES = [0, 8]


def ref_paged_attn(
query: torch.Tensor,
@@ -92,6 +96,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
@torch.inference_mode()
def test_triton_unified_attn(
seq_lens: list[tuple[int, int]],
@@ -103,6 +108,7 @@ def test_triton_unified_attn(
soft_cap: float | None,
num_blocks: int,
q_dtype: torch.dtype | None,
seq_threshold_3D: int,
) -> None:
torch.set_default_device("cuda")

@@ -152,6 +158,20 @@
k_descale = torch.rand(scale_shape, dtype=torch.float32)
v_descale = torch.rand(scale_shape, dtype=torch.float32)

num_par_softmax_segments = 16
softmax_segm_output = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size),
dtype=torch.float32,
)
softmax_segm_max = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
dtype=torch.float32,
)
softmax_segm_expsum = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
dtype=torch.float32,
)

unified_attention(
q=maybe_quantized_query,
k=maybe_quantized_key_cache,
@@ -169,6 +189,11 @@
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
softmax_segm_expsum=softmax_segm_expsum,
)

ref_output = ref_paged_attn(
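As a quick reference for reviewers, here is a minimal caller-side sketch of how the new scratch buffers might be set up before calling `unified_attention()`. The concrete sizes (`seq_threshold_3D = 8`, `num_par_softmax_segments = 16`, the head counts) are illustrative assumptions, not values mandated by the PR; the shapes follow the test above and the "padded head dimension" commit (a70bf68).

```python
import torch
import triton

# Illustrative sizes -- pick whatever matches the model being run.
seq_threshold_3D = 8            # decode batches up to this size use the 3D kernel
num_par_softmax_segments = 16   # parallel softmax segments per sequence
num_query_heads = 32
head_size = 128
device = "cuda"

# The PR pads the head dimension to the next power of 2 (the commit history
# switches the caller to next_power_of_2 from vllm.utils.math_utils; the
# Triton helper below computes the same value).
head_size_padded = triton.next_power_of_2(head_size)

# Scratch buffers for the 3D split-softmax decode path, previously allocated
# inside unified_attention() and now provided by the caller (the commit
# history ties this to full CUDA Graph support).
softmax_segm_output = torch.empty(
    (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
    dtype=torch.float32,
    device=device,
)
softmax_segm_max = torch.empty(
    (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
    dtype=torch.float32,
    device=device,
)
softmax_segm_expsum = torch.empty(
    (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
    dtype=torch.float32,
    device=device,
)
```

These buffers are then passed to `unified_attention()` exactly as in the test; decode-only batches with `num_seqs <= seq_threshold_3D` take the 3D path, while anything containing a prefill or a larger batch falls back to the 2D kernel (`max_seqlen_q > 1 or num_seqs > seq_threshold_3D`).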
133 changes: 30 additions & 103 deletions vllm/attention/ops/triton_unified_attention.py
@@ -36,14 +36,13 @@ def find_seq_idx(
target_idx,
num_seqs,
BLOCK_Q: tl.constexpr,
use_q_block_mode: tl.constexpr,
):
left: tl.int32 = 0
right = num_seqs
while left < right:
mid = (left + right) // 2
val = tl.load(query_start_len_ptr + mid)
mid_val = val // BLOCK_Q + mid if use_q_block_mode else val
mid_val = val // BLOCK_Q + mid

if mid_val <= target_idx:
left = mid + 1
@@ -105,9 +104,7 @@
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)

seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)
seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q)

q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

@@ -393,32 +390,13 @@ def kernel_unified_attention_3d(
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
):
q_block_global_idx = tl.program_id(0)
seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
segm_idx = tl.program_id(2)

seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)

q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

q_block_local_idx = q_block_global_idx - q_block_start_idx

cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index

if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return

# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)

@@ -432,9 +410,9 @@ def kernel_unified_attention_3d(
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_pos = offs_m // num_queries_per_kv

query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_0 = seq_idx + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
query_offset = (
query_offset_0[:, None] * query_stride_0
@@ -443,7 +421,7 @@
)

dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
query_mask_0 = tl.where(query_pos < 1, 1, 0).to(tl.int1)
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

# Q : (BLOCK_M, HEAD_SIZE_PADDED)
@@ -471,7 +449,7 @@
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

# context length for this particular sequences
context_len = seq_len - cur_batch_query_len
context_len = seq_len - 1

# alibi slope for this head
if USE_ALIBI_SLOPES:
@@ -485,31 +463,15 @@
qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
) # shape: [BLOCK_M]

# compute the length of the longest sequence prefix spanned by any
# query token in the current q_block (q_block_local_idx)
max_seq_prefix_len = (
context_len
+ q_block_local_idx * BLOCK_Q
+ (BLOCK_M - 1) // num_queries_per_kv
+ 1
)

# adjust for potential padding in the last q_block by considering the
# actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

# calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, tiles beyond
# this prefix can be skipped)
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
num_tiles = cdiv_fn(seq_len, TILE_SIZE)

# iterate through tiles within current segment
for j in range(
segm_idx * tiles_per_segment,
min((segm_idx + 1) * tiles_per_segment, num_tiles),
):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len
tile_mask = seq_offset < seq_len

physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
@@ -650,7 +612,6 @@ def reduce_segments(
segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments]
seq_lens_ptr, # [num_seqs]
num_seqs, # int
num_query_heads: tl.constexpr, # int
out_scale_inv, # float32
output_stride_0: tl.int64, # int
@@ -659,20 +620,14 @@
TILE_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
query_token_idx = tl.program_id(0)
seq_idx = tl.program_id(0)
query_head_idx = tl.program_id(1)

seq_idx = find_seq_idx(
query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False
)

# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)

@@ -689,7 +644,7 @@

# load segment maxima
segm_offset = (
query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
seq_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
+ query_head_idx * NUM_SEGMENTS_PER_SEQ
+ tl.arange(0, NUM_SEGMENTS_PER_SEQ)
)
@@ -703,7 +658,7 @@

# load, rescale, and add segment attention outputs
segm_output_offset = (
query_token_idx.to(tl.int64)
seq_idx.to(tl.int64)
* (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED
@@ -725,7 +680,7 @@

# write result
output_offset = (
query_token_idx * output_stride_0
seq_idx * output_stride_0
+ query_head_idx * output_stride_1
+ tl.arange(0, HEAD_SIZE_PADDED)
)
@@ -749,6 +704,11 @@ def unified_attention(
q_descale,
k_descale,
v_descale,
seq_threshold_3D,
num_par_softmax_segments,
softmax_segm_output,
softmax_segm_max,
softmax_segm_expsum,
alibi_slopes=None,
output_scale=None,
qq_bias=None,
@@ -793,8 +753,8 @@
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

# if batch contains a prefill
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
# if batch contains a prefill or number of sequences is larger than threshold
if max_seqlen_q > 1 or num_seqs > seq_threshold_3D:
kernel_unified_attention_2d[
(
total_num_q_blocks,
@@ -847,37 +807,10 @@
USE_FP8=output_scale is not None,
)
else:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
# value that showed good performance in tests
NUM_SEGMENTS = 16

segm_output = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
triton.next_power_of_2(head_size),
dtype=torch.float32,
device=q.device,
)
segm_max = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
segm_expsum = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)

kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
kernel_unified_attention_3d[(num_seqs, num_kv_heads, num_par_softmax_segments)](
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
Expand Down Expand Up @@ -913,19 +846,15 @@ def unified_attention(
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
)
reduce_segments[(q.shape[0], num_query_heads)](
reduce_segments[(num_seqs, num_query_heads)](
output_ptr=out,
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
seq_lens_ptr=seqused_k,
num_seqs=num_seqs,
num_query_heads=num_query_heads,
out_scale_inv=1 / output_scale if output_scale is not None else 1.0,
output_stride_0=out.stride(0),
Expand All @@ -934,8 +863,6 @@ def unified_attention(
TILE_SIZE=TILE_SIZE_DECODE,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
USE_FP8=output_scale is not None,
)
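To make the 3D decode path easier to review: each (sequence, KV head, segment) program of `kernel_unified_attention_3d` writes an unnormalized partial output plus its per-segment softmax max and exp-sum, and `reduce_segments` then folds the segments together. The PyTorch sketch below is only a reference for that combine step, not the kernel itself; the function name is mine, and it assumes every segment slot was written, whereas the Triton kernel also has to mask out segments beyond the sequence's last tile.

```python
import torch


def combine_softmax_segments(
    segm_output: torch.Tensor,   # [num_tokens, num_heads, num_segments, head_size_padded]
    segm_max: torch.Tensor,      # [num_tokens, num_heads, num_segments]
    segm_expsum: torch.Tensor,   # [num_tokens, num_heads, num_segments]
) -> torch.Tensor:
    # Bring all segments onto a common scale via the global max per (token, head).
    overall_max = segm_max.amax(dim=-1, keepdim=True)              # [T, H, 1]
    rescale = torch.exp(segm_max - overall_max)                    # [T, H, S]
    # Total softmax denominator after rescaling each segment's partial exp-sum.
    expsum = (segm_expsum * rescale).sum(dim=-1, keepdim=True)     # [T, H, 1]
    # Rescale and sum the unnormalized per-segment outputs, then normalize.
    acc = (segm_output * rescale.unsqueeze(-1)).sum(dim=-2)        # [T, H, D]
    return acc / expsum
```

In the decode-only case handled by this path there is one query token per sequence, so `num_tokens == num_seqs` and only the first `num_seqs` rows of the preallocated buffers carry meaningful data.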