26 changes: 26 additions & 0 deletions tests/kernels/attention/test_triton_unified_attention.py
@@ -22,6 +22,10 @@
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]

# 0: use 2D kernel for decode
# 8: use 3D kernel for decode
SEQ_THRESHOLD_3D_VALUES = [0, 8]


def ref_paged_attn(
query: torch.Tensor,
@@ -92,6 +96,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
@torch.inference_mode()
def test_triton_unified_attn(
seq_lens: list[tuple[int, int]],
@@ -103,6 +108,7 @@ def test_triton_unified_attn(
soft_cap: float | None,
num_blocks: int,
q_dtype: torch.dtype | None,
seq_threshold_3D: int,
) -> None:
torch.set_default_device("cuda")

@@ -152,6 +158,21 @@ def test_triton_unified_attn(
k_descale = torch.rand(scale_shape, dtype=torch.float32)
v_descale = torch.rand(scale_shape, dtype=torch.float32)

num_par_softmax_segments = 16
head_size_padded = 1 << (head_size - 1).bit_length()  # round head_size up to the next power of 2
softmax_segm_output = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
dtype=torch.float32,
)
softmax_segm_max = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
dtype=torch.float32,
)
softmax_segm_expsum = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
dtype=torch.float32,
)

unified_attention(
q=maybe_quantized_query,
k=maybe_quantized_key_cache,
@@ -169,6 +190,11 @@ def unified_attention(
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
softmax_segm_expsum=softmax_segm_expsum,
)

ref_output = ref_paged_attn(
58 changes: 18 additions & 40 deletions vllm/attention/ops/triton_unified_attention.py
@@ -355,7 +355,7 @@ def kernel_unified_attention_2d(
@triton.jit
def kernel_unified_attention_3d(
segm_output_ptr,
# [num_tokens, num_query_heads, num_segments, head_size]
# [num_tokens, num_query_heads, num_segments, head_size_padded]
segm_max_ptr, # [num_tokens, num_query_heads, num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments]
query_ptr, # [num_tokens, num_query_heads, head_size]
@@ -749,6 +749,11 @@ def unified_attention(
q_descale,
k_descale,
v_descale,
seq_threshold_3D,
num_par_softmax_segments,
softmax_segm_output,
softmax_segm_max,
softmax_segm_expsum,
alibi_slopes=None,
output_scale=None,
qq_bias=None,
@@ -793,8 +798,8 @@ def unified_attention(
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

# if batch contains a prefill
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
# if the batch contains a prefill or the number of sequences exceeds the threshold
if max_seqlen_q > 1 or num_seqs > seq_threshold_3D:
kernel_unified_attention_2d[
(
total_num_q_blocks,
@@ -847,37 +852,10 @@ def unified_attention(
USE_FP8=output_scale is not None,
)
else:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
# value that showed good performance in tests
NUM_SEGMENTS = 16

segm_output = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
triton.next_power_of_2(head_size),
dtype=torch.float32,
device=q.device,
)
segm_max = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
segm_expsum = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)

kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
kernel_unified_attention_3d[(num_seqs, num_kv_heads, num_par_softmax_segments)](
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
@@ -917,13 +895,13 @@ def unified_attention(
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
)
reduce_segments[(q.shape[0], num_query_heads)](
reduce_segments[(num_seqs, num_query_heads)](
output_ptr=out,
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
seq_lens_ptr=seqused_k,
num_seqs=num_seqs,
num_query_heads=num_query_heads,
@@ -936,6 +914,6 @@ def unified_attention(
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
USE_FP8=output_scale is not None,
)
96 changes: 95 additions & 1 deletion vllm/v1/attention/backends/triton_attn.py
@@ -17,7 +17,7 @@
triton_reshape_and_cache_flash,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -26,6 +26,7 @@
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
@@ -36,6 +37,11 @@
logger = init_logger(__name__)


# constants
MIN_LAUNCH_GRID_SIZE_2D = 128  # minimum launch grid size of the 2D kernel
NUM_PAR_SOFTMAX_SEGMENTS = 16  # number of parallel tiled-softmax segments


@dataclass
class TritonAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
@@ -54,6 +60,12 @@ class TritonAttentionMetadata:
block_table: torch.Tensor
slot_mapping: torch.Tensor

seq_threshold_3D: int
num_par_softmax_segments: int
softmax_segm_output: torch.Tensor
softmax_segm_max: torch.Tensor
softmax_segm_expsum: torch.Tensor

# For cascade attention.
use_cascade: bool
common_prefix_len: int
@@ -87,6 +99,72 @@ def __init__(
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
self.headdim = model_config.get_head_size()

# Check if CUDA Graphs are enabled for decode
self.decode_cudagraph_enabled = (
self.vllm_config.compilation_config.cudagraph_mode
in (
CUDAGraphMode.FULL_AND_PIECEWISE,
CUDAGraphMode.FULL_DECODE_ONLY,
CUDAGraphMode.FULL,
)
)

# The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
# A lower bound for num_q_blocks is the number of sequences.
# To ensure the minimum launch grid size is achieved, the number of sequences
# must be at least equal to the threshold below.
# If this threshold is not reached (i.e., the batch size is not large enough),
# the 3D kernel will be selected instead.
self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
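# Example (hypothetical model shape, not from this PR): with
# MIN_LAUNCH_GRID_SIZE_2D = 128 and num_heads_kv = 8, the threshold is
# 128 // 8 = 16, i.e. decode batches of up to 16 sequences take the
# 3D kernel path.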

# Modify the threshold if needed.
if self.decode_cudagraph_enabled:
capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
if not capture_sizes:
# If no CUDA Graph capture sizes are specified, the threshold
# is reset to zero, forcing the 2D kernel to be used.
Comment on lines +124 to +125

[Member] What does it mean to use CUDA graphs with no capture sizes? Is this a case that can actually happen?

[Contributor Author] I don't think this can happen (see above: #28306 (comment)). I included those lines just to be sure.

self.seq_threshold_3D = 0
else:
# Select the CUDA Graph capture size closest to self.seq_threshold_3D
# as threshold. This ensures that each captured graph covers the
# correct execution path.
upd_seq_threshold_3D = min(
capture_sizes,
key=lambda x: abs(x - self.seq_threshold_3D),
)
Comment on lines +128 to +134

[Member] What would happen if the threshold isn't one of the capture sizes?

[Contributor Author] Let’s consider an example where the threshold is 12 and the closest capture sizes are 8 and 16. In this case, the CUDA Graph associated with capture size 8 will handle the 3D kernel, while the one for capture size 16 will handle the 2D kernel. If, during normal operation, a batch arrives with a size between 9 and 16, it will be processed by replaying the CUDA Graph for the 2D kernel. Notice that the threshold no longer influences this decision: batches in that range, whether below or above the threshold, are all processed using the 2D kernel graph. Based on this observation, I found it cleaner to align the threshold with an actual capture size so that batches with sizes up to and including the threshold are handled by the 3D kernel.

At the same time, not aligning the threshold with a capture size does not affect output correctness. Therefore, the additional complexity of determining the exact threshold based on capture sizes could be eliminated.
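
A minimal sketch of that alignment step, using the hypothetical numbers from the example above (the capture sizes and the initial threshold of 12 are assumptions for illustration, not values from the PR):

```python
# Hypothetical illustration of aligning seq_threshold_3D to a capture size.
capture_sizes = [1, 2, 4, 8, 16, 32]  # assumed cudagraph capture sizes
seq_threshold_3D = 12                 # assumed initial threshold

# Pick the capture size closest to the threshold; on a tie
# (|8 - 12| == |16 - 12|), min() keeps the first, i.e. smaller, size.
aligned = min(capture_sizes, key=lambda x: abs(x - seq_threshold_3D))
assert aligned == 8

# Keep the allocation bounded: accept the aligned value only if it is
# at most 4x the initial threshold, otherwise force the 2D kernel only.
seq_threshold_3D = aligned if aligned <= 4 * seq_threshold_3D else 0
assert seq_threshold_3D == 8  # batches of <= 8 sequences use the 3D kernel
```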


# If the updated threshold becomes significantly larger than the
# initial value, it is reset to zero. This enforces the use of the
# 2D kernel only and ensures that the size of the allocated
# intermediate structures remains bounded.
if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
self.seq_threshold_3D = upd_seq_threshold_3D
else:
self.seq_threshold_3D = 0
Comment on lines +136 to +143

[Member] How big can the intermediate data structures really get?

[Contributor Author] The data structures have the following dimensions:

  • softmax_segm_output: [seq_threshold_3D, num_heads_q, num_par_softmax_segments, headdim_padded]
  • softmax_segm_max: [seq_threshold_3D, num_heads_q, num_par_softmax_segments]
  • softmax_segm_expsum: [seq_threshold_3D, num_heads_q, num_par_softmax_segments]

Example: with seq_threshold_3D = 16, num_heads_q = 32, num_par_softmax_segments = 16, and headdim_padded = 128, the resulting sizes are:

  softmax_segm_output: 4 MB
  softmax_segm_max: 32 KB
  softmax_segm_expsum: 32 KB

self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
headdim_padded = next_power_of_2(self.headdim)
self.softmax_segm_output = torch.empty(
(
self.seq_threshold_3D,
self.num_heads_q,
self.num_par_softmax_segments,
headdim_padded,
),
dtype=torch.float32,
device=device,
)
self.softmax_segm_max = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)
self.softmax_segm_expsum = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
@@ -143,6 +221,11 @@ def build(
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
seq_threshold_3D=self.seq_threshold_3D,
num_par_softmax_segments=self.num_par_softmax_segments,
softmax_segm_output=self.softmax_segm_output,
softmax_segm_max=self.softmax_segm_max,
softmax_segm_expsum=self.softmax_segm_expsum,
)
return attn_metadata

@@ -346,6 +429,12 @@ def forward(
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table

seq_threshold_3D = attn_metadata.seq_threshold_3D
num_par_softmax_segments = attn_metadata.num_par_softmax_segments
softmax_segm_output = attn_metadata.softmax_segm_output
softmax_segm_max = attn_metadata.softmax_segm_max
softmax_segm_expsum = attn_metadata.softmax_segm_expsum

descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

unified_attention(
@@ -366,6 +455,11 @@
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
softmax_segm_expsum=softmax_segm_expsum,
sinks=self.sinks,
output_scale=output_scale,
)