Skip to content

Commit 15ae8e0

Browse files
Authored by rasmith (Randall Smith) and tjtanaa
[Bugfix][CI/Test][Spec Decode] Fix illegal memory access in offline_inference/spec_decode.py (Issue 27619) (#28432)
Signed-off-by: Randall Smith <ransmith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
1 parent 0b25498 commit 15ae8e0

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

vllm/attention/ops/triton_reshape_and_cache_flash.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def triton_reshape_and_cache_flash(
9797
k_scale: torch.Tensor, # float32
9898
v_scale: torch.Tensor, # float32
9999
):
100-
num_tokens = key.shape[0]
101100
num_heads = key.shape[1]
102101
head_size = key.shape[2]
103102
block_size = key_cache.shape[1]
@@ -155,7 +154,10 @@ def triton_reshape_and_cache_flash(
155154

156155
# TODO(ngl): maybe replace with static launch grid to avoid overhead if
157156
# using cudagraphs
158-
grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))
157+
grid = lambda meta: (
158+
slot_mapping.shape[0],
159+
triton.cdiv(n, meta["TILE_SIZE"]),
160+
)
159161

160162
reshape_and_cache_kernel_flash[grid](
161163
key_ptr=key,

0 commit comments

Comments (0)