3131
3232if current_platform .is_rocm ():
3333 import aiter
34+ from aiter .ops .triton .utils .device_info import get_num_sms
3435
3536 from vllm .triton_utils import tl , triton
3637
3738 def block_size (x , head_dim ):
3839 return min (65536 // x .element_size (), triton .next_power_of_2 (head_dim ))
3940
40- def num_programs (total_tokens ):
41- return min (total_tokens , current_platform . get_cu_count ())
41+ def num_programs (head_dim ):
42+ return min (head_dim , get_num_sms ())
4243
4344 @triton .jit
4445 def cp_mha_gather_cache_kernel (
@@ -57,19 +58,19 @@ def cp_mha_gather_cache_kernel(
5758 x ,
5859 max_block_num ,
5960 num_tokens ,
60- num_programs ,
6161 DEQUANT : tl .constexpr ,
6262 PAGE_SIZE : tl .constexpr ,
6363 CACHE_FORMAT : tl .constexpr ,
6464 BLOCK_SIZE : tl .constexpr ,
65+ NUM_PRGMS : tl .constexpr ,
6566 ):
6667 bid = tl .program_id (0 )
6768 col_offsets = tl .arange (0 , BLOCK_SIZE )
6869 if DEQUANT :
6970 k_scale = tl .load (k_scale_ptr )
7071 v_scale = tl .load (v_scale_ptr )
7172
72- for token_id in tl .range (bid , num_tokens , num_programs ):
73+ for token_id in tl .range (bid , num_tokens , NUM_PRGMS ):
7374 key_ptr_offset = key_ptr + token_id * head_size * num_heads
7475 value_ptr_offset = value_ptr + token_id * head_size * num_heads
7576 batch_idx = tl .load (token_to_batch_ptr + token_id )
@@ -161,11 +162,11 @@ def cp_mha_gather_cache(
161162 x ,
162163 block_tables .size (1 ),
163164 total_tokens ,
164- NUM_PRGMS ,
165165 DEQUANT = dequant ,
166166 PAGE_SIZE = page_size ,
167167 CACHE_FORMAT = kv_cache_layout ,
168168 BLOCK_SIZE = BLOCK_SIZE ,
169+ NUM_PRGMS = NUM_PRGMS ,
169170 )
170171
171172
0 commit comments