@@ -6350,8 +6350,10 @@ def variable_length_memory_efficient_attention(
     if key.shape[1] != num_heads:
         # Repeat key and value along the num_heads dimension
         repeat_factor = num_heads // key.shape[1]
-        key = key.repeat(1, repeat_factor, 1, 1)
-        value = value.repeat(1, repeat_factor, 1, 1)
+        # key = key.repeat(1, repeat_factor, 1, 1)
+        # value = value.repeat(1, repeat_factor, 1, 1)
+        key = key.unsqueeze(2).expand(-1, -1, repeat_factor, -1, -1).reshape(batch_size, num_heads, key_seq_len, head_size)
+        value = value.unsqueeze(2).expand(-1, -1, repeat_factor, -1, -1).reshape(batch_size, num_heads, key_seq_len, head_size)
     # Default scale if not provided
     if scale is None:
         scale = math.sqrt(1.0 / head_size)
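
The unsqueeze/expand/reshape pattern above duplicates each key/value head repeat_factor times in place (equivalent to repeat_interleave along the head dimension), whereas .repeat(1, repeat_factor, 1, 1) tiles the whole head dimension, which pairs query heads with the wrong key/value heads under the usual grouped-query convention where consecutive query heads share one KV head. A minimal sketch of the difference, with toy shapes assumed for illustration and not taken from the patch:

import torch

# Assumed toy shapes, for illustration only.
batch_size, kv_heads, repeat_factor, key_seq_len, head_size = 1, 2, 2, 3, 4
num_heads = kv_heads * repeat_factor
key = torch.randn(batch_size, kv_heads, key_seq_len, head_size)

# Old behaviour: tiles heads as [kv0, kv1, kv0, kv1].
tiled = key.repeat(1, repeat_factor, 1, 1)

# New behaviour: groups heads as [kv0, kv0, kv1, kv1], matching repeat_interleave.
grouped = (key.unsqueeze(2)
              .expand(-1, -1, repeat_factor, -1, -1)
              .reshape(batch_size, num_heads, key_seq_len, head_size))

assert torch.equal(grouped, key.repeat_interleave(repeat_factor, dim=1))
assert not torch.equal(tiled, grouped)  # head orderings differ once kv_heads > 1
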
@@ -6380,8 +6382,9 @@ def variable_length_memory_efficient_attention(
     qk_res = torch.matmul(query, key.transpose(-1, -2))  # [batch_size, num_heads, query_seq_len, key_seq_len]
     # Apply scale
     attention = qk_res * scale
-    attention = attention.masked_fill(~seq_mask, torch.finfo(attention.dtype).min)
+    # attention = attention.masked_fill(~seq_mask, torch.finfo(attention.dtype).min)
     attention = attention + mask
+    attention = attention.masked_fill(~seq_mask, torch.finfo(attention.dtype).min)
     # Softmax over the last dimension
     softmax_result = torch.nn.functional.softmax(attention, dim=-1)
     softmax_result = softmax_result.masked_fill(~seq_mask, 0.0)
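
One plausible reason for moving the masked_fill after the additive mask (not stated in the commit itself): filling with torch.finfo(attention.dtype).min first and then adding a large negative mask can overflow to -inf in low-precision dtypes, while filling last leaves masked logits at exactly finfo.min. A small sketch of that failure mode, assuming fp16 logits and an additive mask that is also finfo.min at padded positions:

import torch

# Assumed fp16 logits and an additive mask of finfo.min at padded positions.
logits = torch.zeros(1, 4, dtype=torch.float16)
mask = torch.full((1, 4), torch.finfo(torch.float16).min, dtype=torch.float16)
seq_mask = torch.tensor([[True, True, False, False]])

# Old order: fill first, then add the mask -> min + min overflows to -inf.
old = logits.masked_fill(~seq_mask, torch.finfo(torch.float16).min) + mask

# New order: add the mask first, then fill -> masked logits stay at finfo.min.
new = (logits + mask).masked_fill(~seq_mask, torch.finfo(torch.float16).min)

print(old)  # -inf at the padded positions
print(new)  # finite everywhere
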