Commit 44e074c

Merge pull request #540 from yangguohao/main
【Hackathon 9th No.15】Fix variable_length_memory_efficient_attention
2 parents 2581110 + 51f369a commit 44e074c

2 files changed, 10 insertions(+), 5 deletions(-)

tester/api_config/config_analyzer.py (4 additions, 2 deletions)

@@ -1948,8 +1948,10 @@ def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
             self.numpy_tensor = self.get_random_numpy_tensor(shape=self.shape, data_type=self.dtype, min=1, max=min(k_seq_len, v_seq_len))
         elif self.check_arg(api_config, 5, "mask"):
             # mask should between -inf and 0 (0 is included)
-            eps = numpy.finfo(self.dtype).eps
-            self.numpy_tensor = self.get_random_numpy_tensor(shape=self.shape, data_type=self.dtype, max=0 + eps)
+            # eps = numpy.finfo(self.dtype).eps
+            # self.numpy_tensor = self.get_random_numpy_tensor(shape=self.shape, data_type=self.dtype, max=0 + eps)
+            # mask should be -inf(masked) or 0(not masked)
+            self.numpy_tensor = numpy.random.randint(0, 2, size=self.shape).astype(self.dtype) * (numpy.finfo(self.dtype).min)
         elif api_config.api_name == "paddle.zeros":
             self.numpy_tensor = numpy.random.randint(0, 2048, size = self.shape)
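The new generator no longer samples a continuous range near 0; each mask element becomes either exactly 0 (attend) or the dtype's most negative finite value (standing in for -inf, masked out). A minimal sketch of the same trick, with an assumed illustrative shape and dtype (the project-internal get_random_numpy_tensor helper is not reproduced here):

    import numpy

    # Assumed illustrative values; the real shape/dtype come from the test config.
    shape = (2, 1, 8, 8)   # [batch, num_heads, query_seq_len, key_seq_len]
    dtype = numpy.float32

    # 0/1 coin flips scaled by the most negative finite value of the dtype:
    # every element ends up as exactly 0 (keep) or finfo.min (mask out).
    mask = numpy.random.randint(0, 2, size=shape).astype(dtype) * numpy.finfo(dtype).min
    assert ((mask == 0) | (mask == numpy.finfo(dtype).min)).all()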

tester/paddle_to_torch/rules.py (6 additions, 3 deletions)

@@ -6350,8 +6350,10 @@ def variable_length_memory_efficient_attention(
         if key.shape[1] != num_heads:
             # Repeat key and value along the num_heads dimension
             repeat_factor = num_heads // key.shape[1]
-            key = key.repeat(1, repeat_factor, 1, 1)
-            value = value.repeat(1, repeat_factor, 1, 1)
+            # key = key.repeat(1, repeat_factor, 1, 1)
+            # value = value.repeat(1, repeat_factor, 1, 1)
+            key = key.unsqueeze(2).expand(-1, -1, repeat_factor, -1, -1).reshape(batch_size, num_heads, key_seq_len, head_size)
+            value = value.unsqueeze(2).expand(-1, -1, repeat_factor, -1, -1).reshape(batch_size, num_heads, key_seq_len, head_size)
         # Default scale if not provided
         if scale is None:
             scale = math.sqrt(1.0 / head_size)
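The two repetition strategies differ in head ordering: repeat(1, r, 1, 1) tiles the whole block of key/value heads r times (order 0, 1, 0, 1), while unsqueeze(2).expand(...).reshape(...) repeats each head consecutively (order 0, 0, 1, 1), the grouped-query layout in which r consecutive query heads share one key/value head. A minimal sketch of the difference, with assumed toy shapes:

    import torch

    # Assumed toy shapes: 2 kv heads, repeat factor 2; each head filled with its index.
    batch, kv_heads, rep, seq, hs = 1, 2, 2, 3, 4
    k = torch.arange(kv_heads).float().view(1, kv_heads, 1, 1).expand(batch, kv_heads, seq, hs)

    tiled = k.repeat(1, rep, 1, 1)                         # old: head order 0, 1, 0, 1
    grouped = (k.unsqueeze(2)
                .expand(-1, -1, rep, -1, -1)
                .reshape(batch, kv_heads * rep, seq, hs))  # new: head order 0, 0, 1, 1

    print(tiled[0, :, 0, 0])    # tensor([0., 1., 0., 1.])
    print(grouped[0, :, 0, 0])  # tensor([0., 0., 1., 1.])
    # The new layout is exactly repeat_interleave along the head dimension.
    assert torch.equal(grouped, torch.repeat_interleave(k, rep, dim=1))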
@@ -6380,8 +6382,9 @@ def variable_length_memory_efficient_attention(
         qk_res = torch.matmul(query, key.transpose(-1, -2))  # [batch_size, num_heads, query_seq_len, key_seq_len]
         # Apply scale
         attention = qk_res * scale
-        attention = attention.masked_fill(~seq_mask, torch.finfo(attention.dtype).min)
+        # attention = attention.masked_fill(~seq_mask, torch.finfo(attention.dtype).min)
         attention = attention + mask
+        attention = attention.masked_fill(~seq_mask, torch.finfo(attention.dtype).min)
         # Softmax over the last dimension
         softmax_result = torch.nn.functional.softmax(attention, dim=-1)
         softmax_result = softmax_result.masked_fill(~seq_mask, 0.0)
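One plausible reading of why the fill was moved after the addition: mask itself may hold finfo.min at masked positions, so filling first and then adding mask can overflow finfo.min + finfo.min to -inf, and a row that is entirely -inf softmaxes to NaN; filling afterwards clamps padded positions back to exactly finfo.min. A minimal sketch of the failure mode, with assumed toy values:

    import torch

    # Assumed toy row: every position is both padded (seq_mask False) and masked by `mask`.
    neg = torch.finfo(torch.float32).min
    attention = torch.zeros(1, 4)
    mask = torch.full((1, 4), neg)
    seq_mask = torch.zeros(1, 4, dtype=torch.bool)

    old = attention.masked_fill(~seq_mask, neg) + mask    # finfo.min + finfo.min -> -inf
    new = (attention + mask).masked_fill(~seq_mask, neg)  # clamped back to finfo.min

    print(torch.softmax(old, dim=-1))  # tensor([[nan, nan, nan, nan]])
    print(torch.softmax(new, dim=-1))  # tensor([[0.2500, 0.2500, 0.2500, 0.2500]])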
