Skip to content

Commit 1233350

Browse files
authored
Revert "[ROCm][BugFix] Remove the usage of device_info from aiter (#28383)"
This reverts commit ca00b1b.
1 parent ca00b1b commit 1233350

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,15 @@
3131

3232
if current_platform.is_rocm():
3333
import aiter
34+
from aiter.ops.triton.utils.device_info import get_num_sms
3435

3536
from vllm.triton_utils import tl, triton
3637

3738
def block_size(x, head_dim):
3839
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
3940

40-
def num_programs(total_tokens):
41-
return min(total_tokens, current_platform.get_cu_count())
41+
def num_programs(head_dim):
42+
return min(head_dim, get_num_sms())
4243

4344
@triton.jit
4445
def cp_mha_gather_cache_kernel(
@@ -57,19 +58,19 @@ def cp_mha_gather_cache_kernel(
5758
x,
5859
max_block_num,
5960
num_tokens,
60-
num_programs,
6161
DEQUANT: tl.constexpr,
6262
PAGE_SIZE: tl.constexpr,
6363
CACHE_FORMAT: tl.constexpr,
6464
BLOCK_SIZE: tl.constexpr,
65+
NUM_PRGMS: tl.constexpr,
6566
):
6667
bid = tl.program_id(0)
6768
col_offsets = tl.arange(0, BLOCK_SIZE)
6869
if DEQUANT:
6970
k_scale = tl.load(k_scale_ptr)
7071
v_scale = tl.load(v_scale_ptr)
7172

72-
for token_id in tl.range(bid, num_tokens, num_programs):
73+
for token_id in tl.range(bid, num_tokens, NUM_PRGMS):
7374
key_ptr_offset = key_ptr + token_id * head_size * num_heads
7475
value_ptr_offset = value_ptr + token_id * head_size * num_heads
7576
batch_idx = tl.load(token_to_batch_ptr + token_id)
@@ -161,11 +162,11 @@ def cp_mha_gather_cache(
161162
x,
162163
block_tables.size(1),
163164
total_tokens,
164-
NUM_PRGMS,
165165
DEQUANT=dequant,
166166
PAGE_SIZE=page_size,
167167
CACHE_FORMAT=kv_cache_layout,
168168
BLOCK_SIZE=BLOCK_SIZE,
169+
NUM_PRGMS=NUM_PRGMS,
169170
)
170171

171172

0 commit comments

Comments
 (0)