22 changes: 14 additions & 8 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -517,12 +517,9 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "FlashAttentionImpl"
+                "Encoder self-attention is not implemented for FlashAttentionImpl"
             )
 
     def extend_forward(
@@ -678,7 +675,14 @@ def forward(
         # performance to make sure it does not introduce any overhead.
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(0)
-        if self.kv_sharing_target_layer_name is None:
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping
@@ -704,8 +708,10 @@
 
         # decode:extend:prefill
         query = query[:num_actual_tokens]
-        key = key[:num_actual_tokens]
-        value = value[:num_actual_tokens]
+        if key is not None:
+            key = key[:num_actual_tokens]
+        if value is not None:
+            value = value[:num_actual_tokens]
 
         output_actual_tokens = output[:num_actual_tokens]
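For context on the new comment in this hunk: in encoder/decoder cross attention the keys and values are computed from the encoder output once, written to the KV cache, and subsequent decode steps call the backend with key=None and value=None. Below is a self-contained toy sketch of that call pattern; CrossAttnBackend and its forward signature are invented for illustration and are not the vLLM classes.

from typing import Optional

import torch


class CrossAttnBackend:
    """Toy stand-in: caches cross-attention K/V once, then serves decode steps."""

    def __init__(self) -> None:
        self.k_cache: Optional[torch.Tensor] = None
        self.v_cache: Optional[torch.Tensor] = None

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if key is not None and value is not None:
            # First call of a request: cache K/V derived from the encoder output.
            self.k_cache, self.v_cache = key, value
        # Every call, including key=None / value=None, attends over the cache.
        scores = torch.einsum("qd,kd->qk", query, self.k_cache)
        return torch.softmax(scores, dim=-1) @ self.v_cache


backend = CrossAttnBackend()
enc = torch.randn(7, 16)                                # encoder output (src_len, dim)
out0 = backend.forward(torch.randn(1, 16), enc, enc)    # first step: K/V provided
out1 = backend.forward(torch.randn(1, 16), None, None)  # later steps: cache reused

The guards added in this file follow the same idea: only cache and slice key/value when they are actually provided.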
14 changes: 12 additions & 2 deletions vllm/v1/attention/backends/rocm_aiter_unified_attn.py
@@ -142,7 +142,14 @@ def forward(
 
         key_cache, value_cache = kv_cache.unbind(0)
 
-        if self.kv_sharing_target_layer_name is None:
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             ops.reshape_and_cache_flash(
@@ -169,7 +176,10 @@
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table
 
-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+        descale_shape = (
+            cu_seqlens_q.shape[0] - 1,
+            key.shape[1] if key is not None else self.num_kv_heads,
+        )
 
         self.unified_attention(
             q=query[:num_actual_tokens],
7 changes: 2 additions & 5 deletions vllm/v1/attention/backends/rocm_attn.py
@@ -238,12 +238,9 @@ def __init__(
 
         RocmAttentionBackend.validate_head_size(head_size)
 
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "RocmAttentionImpl"
+                "Encoder self-attention is not implemented for RocmAttentionImpl"
             )
 
         self.fp8_dtype = current_platform.fp8_dtype()
Review comment on lines 239 to 246 of rocm_attn.py:

P1: Handle encoder–decoder calls without key/value tensors

The constructor now accepts AttentionType.ENCODER_DECODER, but forward still assumes key and value are always present. During decoder-side cross attention, later decode steps reuse the encoder KV cache and invoke this path with key=None/value=None. The new guard no longer blocks these calls, so chunked_prefill_paged_decode immediately dereferences key[:num_actual_tokens] and key.shape, raising an exception before any attention is computed. Either revert the constructor restriction or update forward to fall back to the cached tensors when key/value are None.
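One possible shape for the second option, mirroring the guards this PR adds in rocm_aiter_fa.py and rocm_aiter_unified_attn.py, is to make the cache write, the token slicing, and the shape lookups in RocmAttentionImpl.forward tolerate key=None/value=None. The sketch below is a simplified stand-in, not the actual backend: the flat cache layout, function name, and signature are assumptions for illustration only.

# Hypothetical sketch, not vLLM code. A flat cache of shape
# (num_slots, num_kv_heads, head_size) stands in for the real paged KV cache.
from typing import Optional

import torch


def prepare_qkv(
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    value: Optional[torch.Tensor],
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    num_actual_tokens: int,
    num_kv_heads: int,
):
    # Write new keys/values into the cache only when they were provided. In
    # cross attention, decode steps after the first pass key=None / value=None
    # and must reuse the entries cached from the encoder output.
    if key is not None and value is not None:
        key_cache[slot_mapping] = key
        value_cache[slot_mapping] = value

    # Guard the slicing the comment flags: key/value may legitimately be None.
    query = query[:num_actual_tokens]
    if key is not None:
        key = key[:num_actual_tokens]
    if value is not None:
        value = value[:num_actual_tokens]

    # Shape lookups fall back to stored metadata instead of key.shape.
    heads = key.shape[1] if key is not None else num_kv_heads
    return query, key, value, heads

The attention kernel would then read K and V through the block table whenever the sliced key/value are None, which is the behavior the comment describes for decoder-side cross attention.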
