diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index bf4d2179af5f..aaac33be0de2 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -22,6 +22,10 @@ # one value small enough to test the schema op check NUM_BLOCKS = [32768, 2048] +# 0: use 2D kernel for decode +# 8: use 3D kernel for decode +SEQ_THRESHOLD_3D_VALUES = [0, 8] + def ref_paged_attn( query: torch.Tensor, @@ -92,6 +96,7 @@ def ref_paged_attn( @pytest.mark.parametrize("soft_cap", [None, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("q_dtype", QDTYPES) +@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES) @torch.inference_mode() def test_triton_unified_attn( seq_lens: list[tuple[int, int]], @@ -103,6 +108,7 @@ def test_triton_unified_attn( soft_cap: float | None, num_blocks: int, q_dtype: torch.dtype | None, + seq_threshold_3D: int, ) -> None: torch.set_default_device("cuda") @@ -152,6 +158,28 @@ def test_triton_unified_attn( k_descale = torch.rand(scale_shape, dtype=torch.float32) v_descale = torch.rand(scale_shape, dtype=torch.float32) + num_queries_per_kv = num_query_heads // num_kv_heads + BLOCK_M = ( + 16 if num_queries_per_kv <= 16 else 1 << (num_queries_per_kv - 1).bit_length() + ) # next power of 2 value + BLOCK_Q = BLOCK_M // num_queries_per_kv + + block_q_seq_boundaries_tensor = torch.empty(num_seqs + 1, dtype=torch.int32) + if max_query_len > 1: + block_q_seq_boundaries_tensor[0] = 0 + block_q_seq_boundaries_tensor[1 : cu_query_lens.numel()].copy_( + cu_query_lens[1:] + ) + block_q_seq_boundaries_tensor[1 : cu_query_lens.numel()].sub_( + cu_query_lens[:-1] + ) + block_q_seq_boundaries_tensor[1 : cu_query_lens.numel()].add_(BLOCK_Q - 1) + block_q_seq_boundaries_tensor[1 : cu_query_lens.numel()].floor_divide_(BLOCK_Q) + block_q_seq_boundaries_tensor[: cu_query_lens.numel()].cumsum_(dim=0) + num_q_blocks = block_q_seq_boundaries_tensor[cu_query_lens.numel() - 1] + else: + num_q_blocks = len(seq_lens) + unified_attention( q=maybe_quantized_query, k=maybe_quantized_key_cache, @@ -169,6 +197,11 @@ def test_triton_unified_attn( q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + BLOCK_M=BLOCK_M, + BLOCK_Q=BLOCK_Q, + num_q_blocks=num_q_blocks, + block_q_seq_boundaries_tensor=block_q_seq_boundaries_tensor, + seq_threshold_3D=seq_threshold_3D, ) ref_output = ref_paged_attn( diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 565be1c39bec..00c7c5be3925 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -98,28 +98,39 @@ def kernel_unified_attention_2d( BLOCK_Q: tl.constexpr, # int num_seqs: tl.int32, BLOCK_M: tl.constexpr, # int + block_q_seq_boundaries_ptr, # [num_seqs+1] + only_decode, # bool USE_FP8: tl.constexpr, # bool FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, ): - q_block_global_idx = tl.program_id(0) - kv_head_idx = tl.program_id(1) + if only_decode: + seq_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) - seq_idx = find_seq_idx( - query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True - ) + q_block_local_idx = 0 + cur_batch_in_all_start_index = seq_idx + cur_batch_query_len = 1 - q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + else: + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) - q_block_local_idx = q_block_global_idx - q_block_start_idx + seq_idx = find_seq_idx( + block_q_seq_boundaries_ptr, q_block_global_idx, num_seqs, BLOCK_Q, False + ) - cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) - cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + q_block_start_idx = tl.load(block_q_seq_boundaries_ptr + seq_idx) - cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + q_block_local_idx = q_block_global_idx - q_block_start_idx - if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: - return + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) @@ -749,6 +760,11 @@ def unified_attention( q_descale, k_descale, v_descale, + BLOCK_M=None, + BLOCK_Q=None, + num_q_blocks=None, + block_q_seq_boundaries_tensor=None, + seq_threshold_3D=None, alibi_slopes=None, output_scale=None, qq_bias=None, @@ -771,33 +787,49 @@ def unified_attention( num_queries_per_kv = num_query_heads // num_kv_heads head_size = q.shape[2] - BLOCK_M = ( - 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) - ) - BLOCK_Q = BLOCK_M // num_queries_per_kv - - # Ideally we would launch with kernel with: - # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. - # However, it is slow to realize the query_lens on cpu. - # Instead we use upper-bound: - # \sum_i[ceil(query_len[i] / BLOCK_Q)] - # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] - # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs - # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs - # = floor(q.shape[0] / BLOCK_Q) + num_seqs - total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + # Assign the following variables if they are not assigned in the attention metadata. + # This ensures backward compatibility with callers using an earlier version of this + # function. However, it is recommended to include these assignments in the + # attention metadata itself, as performing them here may negatively impact + # performance. + if ( + BLOCK_M is None + or BLOCK_Q is None + or num_q_blocks is None + or block_q_seq_boundaries_tensor is None + or seq_threshold_3D is None + ): + BLOCK_M = ( + 16 + if num_queries_per_kv <= 16 + else triton.next_power_of_2(num_queries_per_kv) + ) + BLOCK_Q = BLOCK_M // num_queries_per_kv + + block_q_seq_boundaries_tensor = torch.empty( + num_seqs + 1, dtype=torch.int32, device=cu_seqlens_q.device + ) + block_q_seq_boundaries_tensor[0] = 0 + block_q_seq_boundaries_tensor[1:].copy_(cu_seqlens_q[1:]) + block_q_seq_boundaries_tensor[1:].sub_(cu_seqlens_q[:-1]) + block_q_seq_boundaries_tensor[1:].add_(BLOCK_Q - 1) + block_q_seq_boundaries_tensor[1:].floor_divide_(BLOCK_Q) + block_q_seq_boundaries_tensor.cumsum_(dim=0) + num_q_blocks = block_q_seq_boundaries_tensor[-1] + seq_threshold_3D = 128 // num_kv_heads # Assigning default tile sizes for prefill and decode. # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1) # and at least 16 for all other data types. - TILE_SIZE_PREFILL = 32 - TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 + TILE_SIZE_2D_PREFILL = 32 + TILE_SIZE_2D_DECODE = 32 + TILE_SIZE_3D_DECODE = 16 if q.element_size() >= 2 else 32 # if batch contains a prefill - if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: + if max_seqlen_q > 1: kernel_unified_attention_2d[ ( - total_num_q_blocks, + num_q_blocks, num_kv_heads, ) ]( @@ -824,7 +856,7 @@ def unified_attention( output_stride_1=out.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, - TILE_SIZE=TILE_SIZE_PREFILL, + TILE_SIZE=TILE_SIZE_2D_PREFILL, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -844,98 +876,160 @@ def unified_attention( BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, BLOCK_M=BLOCK_M, + block_q_seq_boundaries_ptr=block_q_seq_boundaries_tensor, + only_decode=False, USE_FP8=output_scale is not None, ) else: - # for initial version, NUM_SEGMENTS = 16 is chosen as a default - # value that showed good performance in tests - NUM_SEGMENTS = 16 - - segm_output = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - triton.next_power_of_2(head_size), - dtype=torch.float32, - device=q.device, - ) - segm_max = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - dtype=torch.float32, - device=q.device, - ) - segm_expsum = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - dtype=torch.float32, - device=q.device, - ) + # if the number of sequences is larger than threshold + if num_seqs > seq_threshold_3D: + kernel_unified_attention_2d[ + ( + num_seqs, + num_kv_heads, + ) + ]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + out_scale=1 / output_scale if output_scale is not None else 1.0, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_2D_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + block_q_seq_boundaries_ptr=block_q_seq_boundaries_tensor, + only_decode=True, + USE_FP8=output_scale is not None, + ) + else: + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + + # for initial version, NUM_SEGMENTS = 16 is chosen as a default + # value that showed good performance in tests + NUM_SEGMENTS = 16 + + segm_output = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + triton.next_power_of_2(head_size), + dtype=torch.float32, + device=q.device, + ) + segm_max = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + segm_expsum = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) - kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, - query_ptr=q, - key_cache_ptr=k, - value_cache_ptr=v, - sink_ptr=sinks, - block_tables_ptr=block_table, - seq_lens_ptr=seqused_k, - alibi_slopes_ptr=alibi_slopes, - qq_bias_ptr=qq_bias, - scale=softmax_scale, - k_scale=k_descale, - v_scale=v_descale, - softcap=softcap, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - block_table_stride=block_table.stride(0), - query_stride_0=q.stride(0), - query_stride_1=q.stride(1), - qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, - BLOCK_SIZE=block_size, - TILE_SIZE=TILE_SIZE_DECODE, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - USE_ALIBI_SLOPES=use_alibi_slopes, - USE_QQ_BIAS=use_qq_bias, - USE_SOFTCAP=(softcap > 0), - USE_SINKS=(sinks is not None), - SLIDING_WINDOW=(1 + window_size[0]), - stride_k_cache_0=k.stride(0), - stride_k_cache_1=k.stride(1), - stride_k_cache_2=k.stride(2), - stride_k_cache_3=k.stride(3), - stride_v_cache_0=v.stride(0), - stride_v_cache_1=v.stride(1), - stride_v_cache_2=v.stride(2), - stride_v_cache_3=v.stride(3), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - num_seqs=num_seqs, - BLOCK_M=BLOCK_M, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, - ) - reduce_segments[(q.shape[0], num_query_heads)]( - output_ptr=out, - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, - seq_lens_ptr=seqused_k, - num_seqs=num_seqs, - num_query_heads=num_query_heads, - out_scale_inv=1 / output_scale if output_scale is not None else 1.0, - output_stride_0=out.stride(0), - output_stride_1=out.stride(1), - block_table_stride=block_table.stride(0), - TILE_SIZE=TILE_SIZE_DECODE, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, - USE_FP8=output_scale is not None, - ) + kernel_unified_attention_3d[ + (total_num_q_blocks, num_kv_heads, NUM_SEGMENTS) + ]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_3D_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) + reduce_segments[(q.shape[0], num_query_heads)]( + output_ptr=out, + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + seq_lens_ptr=seqused_k, + num_seqs=num_seqs, + num_query_heads=num_query_heads, + out_scale_inv=1 / output_scale if output_scale is not None else 1.0, + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + block_table_stride=block_table.stride(0), + TILE_SIZE=TILE_SIZE_3D_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + USE_FP8=output_scale is not None, + ) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index d051a89f03bb..fd7a0a20dbee 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -17,7 +17,7 @@ triton_reshape_and_cache_flash, ) from vllm.attention.ops.triton_unified_attention import unified_attention -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -26,6 +26,7 @@ ) from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability +from vllm.utils.math_utils import next_power_of_2 from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -36,6 +37,10 @@ logger = init_logger(__name__) +# constants +MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel + + @dataclass class TritonAttentionMetadata: # NOTE(sang): Definition of context_len, query_len, and seq_len. @@ -53,6 +58,11 @@ class TritonAttentionMetadata: seq_lens: torch.Tensor block_table: torch.Tensor slot_mapping: torch.Tensor + BLOCK_M: int + BLOCK_Q: int + num_q_blocks: int + block_q_seq_boundaries_tensor: torch.Tensor + seq_threshold_3D: int # For cascade attention. use_cascade: bool @@ -87,6 +97,43 @@ def __init__( self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() + self.block_q_seq_boundaries_tensor = torch.empty( + self.vllm_config.scheduler_config.max_num_seqs + 1, + dtype=torch.int32, + device=device, + ) + + # Check if CUDA Graphs are enabled for decode + self.decode_cudagraph_enabled = ( + self.vllm_config.compilation_config.cudagraph_mode + in ( + CUDAGraphMode.FULL_AND_PIECEWISE, + CUDAGraphMode.FULL_DECODE_ONLY, + CUDAGraphMode.FULL, + ) + ) + + # The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv). + # A lower bound for num_q_blocks is the number of sequences. + # To ensure the minimum launch grid size is achieved, the number of sequences + # must be at least equal to the threshold below. + # If this threshold is not reached (i.e., the batch size is not large enough), + # the 3D kernel will be selected instead. + self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv + + # Modify the threshold if needed. + if self.decode_cudagraph_enabled: + capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified." + + # Select the CUDA Graph capture size closest to self.seq_threshold_3D + # as threshold. This ensures that each captured graph covers the + # correct execution path. + self.seq_threshold_3D = min( + capture_sizes, + key=lambda x: abs(x - self.seq_threshold_3D), + ) + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> TritonAttentionMetadata: @@ -112,6 +159,33 @@ def build( block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + num_queries_per_kv = self.num_heads_q // self.num_heads_kv + BLOCK_M = ( + 16 if num_queries_per_kv <= 16 else next_power_of_2(num_queries_per_kv) + ) + BLOCK_Q = BLOCK_M // num_queries_per_kv + + if max_seq_len > 1: + self.block_q_seq_boundaries_tensor[0] = 0 + self.block_q_seq_boundaries_tensor[1 : query_start_loc.numel()].copy_( + query_start_loc[1:] + ) + self.block_q_seq_boundaries_tensor[1 : query_start_loc.numel()].sub_( + query_start_loc[:-1] + ) + self.block_q_seq_boundaries_tensor[1 : query_start_loc.numel()].add_( + BLOCK_Q - 1 + ) + self.block_q_seq_boundaries_tensor[ + 1 : query_start_loc.numel() + ].floor_divide_(BLOCK_Q) + self.block_q_seq_boundaries_tensor[: query_start_loc.numel()].cumsum_(dim=0) + self.num_q_blocks = self.block_q_seq_boundaries_tensor[ + query_start_loc.numel() - 1 + ] + else: + self.num_q_blocks = len(seq_lens) + use_cascade = common_prefix_len > 0 if use_cascade: @@ -143,6 +217,11 @@ def build( prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, + BLOCK_M=BLOCK_M, + BLOCK_Q=BLOCK_Q, + num_q_blocks=self.num_q_blocks, + block_q_seq_boundaries_tensor=self.block_q_seq_boundaries_tensor, + seq_threshold_3D=self.seq_threshold_3D, ) return attn_metadata @@ -350,6 +429,13 @@ def forward( max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table + BLOCK_M = attn_metadata.BLOCK_M + BLOCK_Q = attn_metadata.BLOCK_Q + num_q_blocks = attn_metadata.num_q_blocks + block_q_seq_boundaries_tensor = attn_metadata.block_q_seq_boundaries_tensor + + seq_threshold_3D = attn_metadata.seq_threshold_3D + descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2]) unified_attention( @@ -370,6 +456,11 @@ def forward( q_descale=None, # Not supported k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + BLOCK_M=BLOCK_M, + BLOCK_Q=BLOCK_Q, + num_q_blocks=num_q_blocks, + block_q_seq_boundaries_tensor=block_q_seq_boundaries_tensor, + seq_threshold_3D=seq_threshold_3D, sinks=self.sinks, output_scale=output_scale, )