Skip to content

Commit 5253f42

Browse files
authored
[ROCm] Support for Whisper v1 with Aiter Unified Attention and Aiter Flash Attention (#28376)
Signed-off-by: apinge <Tong.Qiu2@amd.com>
1 parent 3085478 commit 5253f42

File tree

3 files changed

+28
-15
lines changed

3 files changed

+28
-15
lines changed

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -517,12 +517,9 @@ def __init__(
517517
assert self.num_heads % self.num_kv_heads == 0
518518
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
519519

520-
if attn_type != AttentionType.DECODER:
520+
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
521521
raise NotImplementedError(
522-
"Encoder self-attention and "
523-
"encoder/decoder cross-attention "
524-
"are not implemented for "
525-
"FlashAttentionImpl"
522+
"Encoder self-attention is not implemented for FlashAttentionImpl"
526523
)
527524

528525
def extend_forward(
@@ -678,7 +675,14 @@ def forward(
678675
# performance to make sure it does not introduce any overhead.
679676
num_actual_tokens = attn_metadata.num_actual_tokens
680677
key_cache, value_cache = kv_cache.unbind(0)
681-
if self.kv_sharing_target_layer_name is None:
678+
# key and value may be None in the case of cross attention. They are
679+
# calculated once based on the output from the encoder and then cached
680+
# in KV cache.
681+
if (
682+
self.kv_sharing_target_layer_name is None
683+
and key is not None
684+
and value is not None
685+
):
682686
# Reshape the input keys and values and store them in the cache.
683687
# Skip this if sharing KV cache with an earlier attention layer.
684688
# NOTE(woosuk): Here, key and value are padded while slot_mapping
@@ -704,8 +708,10 @@ def forward(
704708

705709
# decode:extend:prefill
706710
query = query[:num_actual_tokens]
707-
key = key[:num_actual_tokens]
708-
value = value[:num_actual_tokens]
711+
if key is not None:
712+
key = key[:num_actual_tokens]
713+
if value is not None:
714+
value = value[:num_actual_tokens]
709715

710716
output_actual_tokens = output[:num_actual_tokens]
711717

vllm/v1/attention/backends/rocm_aiter_unified_attn.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,14 @@ def forward(
142142

143143
key_cache, value_cache = kv_cache.unbind(0)
144144

145-
if self.kv_sharing_target_layer_name is None:
145+
# key and value may be None in the case of cross attention. They are
146+
# calculated once based on the output from the encoder and then cached
147+
# in KV cache.
148+
if (
149+
self.kv_sharing_target_layer_name is None
150+
and key is not None
151+
and value is not None
152+
):
146153
# Reshape the input keys and values and store them in the cache.
147154
# Skip this if sharing KV cache with an earlier attention layer.
148155
ops.reshape_and_cache_flash(
@@ -169,7 +176,10 @@ def forward(
169176
max_seqlen_k = attn_metadata.max_seq_len
170177
block_table = attn_metadata.block_table
171178

172-
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
179+
descale_shape = (
180+
cu_seqlens_q.shape[0] - 1,
181+
key.shape[1] if key is not None else self.num_kv_heads,
182+
)
173183

174184
self.unified_attention(
175185
q=query[:num_actual_tokens],

vllm/v1/attention/backends/rocm_attn.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,9 @@ def __init__(
238238

239239
RocmAttentionBackend.validate_head_size(head_size)
240240

241-
if attn_type != AttentionType.DECODER:
241+
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
242242
raise NotImplementedError(
243-
"Encoder self-attention and "
244-
"encoder/decoder cross-attention "
245-
"are not implemented for "
246-
"RocmAttentionImpl"
243+
"Encoder self-attention is not implemented for RocmAttentionImpl"
247244
)
248245

249246
self.fp8_dtype = current_platform.fp8_dtype()

0 commit comments

Comments (0)