Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
8126fa1
remove prefill support from 3d kernel
jvlunteren Nov 3, 2025
4a54c08
formatting
jvlunteren Nov 3, 2025
f1f58cc
Merge branch 'main' into jvl-triton-attn-upd1
jvlunteren Nov 3, 2025
84c5cd7
adapt 3D kernel for full CUDA Graph support
jvlunteren Nov 7, 2025
39f52b4
formatting
jvlunteren Nov 7, 2025
3102959
update unit test
jvlunteren Nov 7, 2025
9bcc1fb
corrected comment
jvlunteren Nov 7, 2025
f3fdb32
Merge branch 'main' into jvl-triton-attn-upd2
jvlunteren Nov 10, 2025
53d7b8b
added check for empty cudagraph_capture_sizes
jvlunteren Nov 10, 2025
a70bf68
allocate softmax buffers with padded head dimension
jvlunteren Nov 10, 2025
96576d8
Merge branch 'main' into jvl-triton-attn-upd2
jvlunteren Nov 11, 2025
a62aa11
fix failing ruff check
jvlunteren Nov 11, 2025
94e5ff1
replace estimation of q block counts by exact calculation
jvlunteren Nov 12, 2025
ec98ba0
Merge branch 'vllm-project:main' into jvl-triton-attn-upd3
jvlunteren Nov 12, 2025
349dd09
Merge branch 'vllm-project:main' into jvl-triton-attn-upd3
jvlunteren Nov 12, 2025
902f988
Merge branch 'main' into jvl-triton-attn-upd3
jvlunteren Nov 13, 2025
020e207
removed dependencies on other PRs
jvlunteren Nov 13, 2025
172487f
formatting
jvlunteren Nov 13, 2025
303def0
Merge branch 'main' into jvl-triton-attn-upd3
jvlunteren Nov 21, 2025
6377085
formatting
jvlunteren Nov 21, 2025
02a2f9a
Merge branch 'main' into jvl-triton-attn-upd3
jvlunteren Nov 21, 2025
cdfaee7
Merge branch 'main' into jvl-triton-attn-upd3
jvlunteren Nov 27, 2025
cc11f8f
remove redundant checks, use math_utils
jvlunteren Nov 27, 2025
1f44588
made new unified_attention() arguments optional to preserve backward …
jvlunteren Nov 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tests/kernels/attention/test_triton_unified_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]

# 0: use 2D kernel for decode
# 8: use 3D kernel for decode
SEQ_THRESHOLD_3D_VALUES = [0, 8]


def ref_paged_attn(
query: torch.Tensor,
Expand Down Expand Up @@ -92,6 +96,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
@torch.inference_mode()
def test_triton_unified_attn(
seq_lens: list[tuple[int, int]],
Expand All @@ -103,6 +108,7 @@ def test_triton_unified_attn(
soft_cap: float | None,
num_blocks: int,
q_dtype: torch.dtype | None,
seq_threshold_3D: int,
) -> None:
torch.set_default_device("cuda")

Expand Down Expand Up @@ -152,6 +158,28 @@ def test_triton_unified_attn(
k_descale = torch.rand(scale_shape, dtype=torch.float32)
v_descale = torch.rand(scale_shape, dtype=torch.float32)

num_queries_per_kv = num_query_heads // num_kv_heads
BLOCK_M = (
16 if num_queries_per_kv <= 16 else 1 << (num_queries_per_kv - 1).bit_length()
) # next power of 2 value
BLOCK_Q = BLOCK_M // num_queries_per_kv

block_q_seq_boundaries_tensor = torch.empty(num_seqs + 1, dtype=torch.int32)
if max_query_len > 1:
block_q_seq_boundaries_tensor[0] = 0
block_q_seq_boundaries_tensor[1 : cu_query_lens.numel()].copy_(
cu_query_lens[1:]
)
block_q_seq_boundaries_tensor[1 : cu_query_lens.numel()].sub_(
cu_query_lens[:-1]
)
block_q_seq_boundaries_tensor[1 : cu_query_lens.numel()].add_(BLOCK_Q - 1)
block_q_seq_boundaries_tensor[1 : cu_query_lens.numel()].floor_divide_(BLOCK_Q)
block_q_seq_boundaries_tensor[: cu_query_lens.numel()].cumsum_(dim=0)
num_q_blocks = block_q_seq_boundaries_tensor[cu_query_lens.numel() - 1]
else:
num_q_blocks = len(seq_lens)

unified_attention(
q=maybe_quantized_query,
k=maybe_quantized_key_cache,
Expand All @@ -169,6 +197,11 @@ def test_triton_unified_attn(
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
BLOCK_M=BLOCK_M,
BLOCK_Q=BLOCK_Q,
num_q_blocks=num_q_blocks,
block_q_seq_boundaries_tensor=block_q_seq_boundaries_tensor,
seq_threshold_3D=seq_threshold_3D,
)

ref_output = ref_paged_attn(
Expand Down
Loading