Commit bc31eb7

account for learned memory key values in flex compress mask, also cleanup some regular attend logic
1 parent 583cc11 commit bc31eb7

3 files changed (+54, -31 lines)

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 52 additions & 29 deletions
@@ -53,17 +53,20 @@ def sliding_mask(_, __, q_idx, kv_idx):
     block_mask = create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask

-def create_compress_mask(seq_len, kv_seq_len, compress_block_size):
+def create_compress_mask(seq_len, kv_seq_len, compress_block_size, mem_kv_len = 0):
     # cannot be used as using attention logits for importance score
     # but just to show the immense potential of flex attention

     def compress_mask(_, __, q_idx, kv_idx):
+        is_mem_kv = kv_idx < mem_kv_len
+
+        kv_without_mem = kv_idx - mem_kv_len
         compress_kv_idx = (kv_idx * compress_block_size) + (compress_block_size - 1)

         causal_mask = q_idx > compress_kv_idx
-        return causal_mask
+        return causal_mask | is_mem_kv

-    block_mask = create_block_mask(compress_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True)
+    block_mask = create_block_mask(compress_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len + mem_kv_len, _compile = True)
     return block_mask

 def create_fine_mask(seq_len, fine_block_size):
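
For intuition, here is a tiny dense, plain-torch sketch (not part of the commit) of what the updated predicate admits: the first mem_kv_len columns are the learned memory key / values and stay visible to every query, while the remaining columns stay causal at the granularity of each compressed block's last covered position (offsetting block positions past the memory slots, which is what the newly introduced kv_without_mem is for).

import torch

# illustration only - tiny dense equivalent of the compress mask semantics
seq_len, num_blocks, compress_block_size, mem_kv_len = 8, 2, 4, 2

q_idx  = torch.arange(seq_len)[:, None]                   # query positions as rows
kv_idx = torch.arange(mem_kv_len + num_blocks)[None, :]   # memory kv first, then compressed blocks

is_mem_kv = kv_idx < mem_kv_len
kv_without_mem = kv_idx - mem_kv_len

# last absolute position summarized by each compressed block
compress_kv_idx = (kv_without_mem * compress_block_size) + (compress_block_size - 1)

mask = (q_idx > compress_kv_idx) | is_mem_kv              # True = query may attend
print(mask.int())
# row 0 -> 1, 1, 0, 0   (only the two memory kv are visible to the first query)
# row 7 -> 1, 1, 1, 0   (plus the first compressed block, which covers positions 0-3)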
@@ -134,6 +137,41 @@ def interpolate_1d(x, length, mode = 'bilinear'):
 def straight_through(t, target):
     return t + (target - t).detach()

+# attend function
+
+def attend(
+    q, k, v,
+    mask = None,
+    return_attn = False,
+    scale = None
+):
+    scale = default(scale, q.shape[-1] ** -0.5)
+
+    q_heads, k_heads = q.shape[1], k.shape[1]
+    num_grouped_queries = q_heads // k_heads
+
+    q = rearrange(q, 'b (h qh) ... -> b h qh ...', qh = num_grouped_queries)
+
+    sim = einsum(q, k, 'b h qh i d, b h j d -> b h qh i j') * scale
+
+    mask_value = -torch.finfo(sim.dtype).max
+
+    if exists(mask):
+        sim = sim.masked_fill(~mask, mask_value)
+
+    attn = sim.softmax(dim = -1)
+
+    attn_out = einsum(attn, v, 'b h qh i j, b h j d -> b h qh i d')
+
+    attn_out = rearrange(attn_out, 'b h qh ... -> b (h qh) ...')
+
+    if not return_attn:
+        return attn_out
+
+    attn = rearrange(attn, 'b h qh ... -> b (h qh) ...')
+
+    return attn_out, attn
+
 # classes

 class SparseAttention(Module):
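
A minimal usage sketch for the new attend helper (assuming it is imported from the module in this diff): query heads are grouped over fewer key / value heads, masked, softmaxed, and folded back, so 8 query heads can share 4 kv heads.

import torch
from native_sparse_attention_pytorch.native_sparse_attention import attend

q = torch.randn(1, 8, 32, 64)    # (batch, query heads, seq, dim_head)
k = torch.randn(1, 4, 32, 64)    # (batch, kv heads, seq, dim_head) - 2 grouped queries per kv head
v = torch.randn(1, 4, 32, 64)

causal = torch.ones(32, 32, dtype = torch.bool).tril()

out, attn = attend(q, k, v, mask = causal, return_attn = True)
print(out.shape, attn.shape)     # torch.Size([1, 8, 32, 64]) torch.Size([1, 8, 32, 32])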
@@ -309,26 +347,13 @@ def forward(
         ck = cat((mem_ck, ck), dim = -2)
         cv = cat((mem_cv, cv), dim = -2)

-        cq = rearrange(cq, 'b (h qh) ... -> b h qh ...', qh = self.num_grouped_queries)
-
-        csim = einsum(cq, ck, 'b h qh i d, b h j d -> b h qh i j') * self.scale
-
         cq_seq = arange(seq_len, device = device)
-
         ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
         ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)

         cmask = einx.less('j, i -> i j', ck_seq, cq_seq)

-        mask_value = -torch.finfo(csim.dtype).max
-
-        csim = csim.masked_fill(~cmask, mask_value)
-
-        cattn = csim.softmax(dim = -1)
-
-        compressed_attn_out = einsum(cattn, cv, 'b h qh i j, b h j d -> b h qh i d')
-
-        compressed_attn_out, cattn = tuple(rearrange(t, 'b h qh ... -> b (h qh) ...') for t in (compressed_attn_out, cattn))
+        compressed_attn_out, cattn = attend(cq, ck, cv, mask = cmask, return_attn = True)

         # for 2. and 3., will give them relative positions with rotary - compressed needs to be handled separately (even if they already have intra block absolute positions)
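
A small worked example (illustration only) of the cmask construction kept above: each compressed block is keyed by the last absolute position it summarizes, and padding with -1 for the num_mem_compress_kv memory slots makes the less-than test pass for every query position.

import torch
import torch.nn.functional as F
import einx

compress_block_size, num_compress_blocks, num_mem_compress_kv, seq_len = 4, 3, 2, 8

cq_seq = torch.arange(seq_len)
ck_seq = ((torch.arange(num_compress_blocks) + 1) * compress_block_size) - 1   # tensor([ 3,  7, 11])
ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)                   # tensor([-1, -1,  3,  7, 11])

cmask = einx.less('j, i -> i j', ck_seq, cq_seq)   # cmask[i, j] = ck_seq[j] < cq_seq[i]
print(cmask[0])   # tensor([ True,  True, False, False, False]) - memory kv visible even at query 0
print(cmask[4])   # tensor([ True,  True,  True, False, False]) - first block now strictly in the past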

@@ -441,15 +466,21 @@ def forward(

             # fine attention

-            fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))
+            fmask = rearrange(fmask, 'b h ... -> b h 1 ...')

-            fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
+            fq = rearrange(fq, 'b (h qh) ... -> b h qh ...', qh = self.num_grouped_queries)
+
+            fsim = einsum(fq, fk, 'b h qh i d, b h i j d -> b h qh i j') * self.scale
+
+            mask_value = -torch.finfo(fsim.dtype).max

             fsim = fsim.masked_fill(~fmask, mask_value)

             fattn = fsim.softmax(dim = -1)

-            fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
+            fine_attn_out = einsum(fattn, fv, 'b h qh i j, b h i j d -> b h qh i d')
+
+            fine_attn_out = rearrange(fine_attn_out, 'b h qh ... -> b (h qh) ...')

             fine_attn_out = fine_attn_out[..., :seq_len, :]
         else:
@@ -458,15 +489,7 @@ def forward(
             seq_len = fk.shape[-2]
             fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()

-            fk, fv = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv))
-
-            fsim = einsum(fq, fk, 'b h i d, b h j d -> b h i j') * self.scale
-
-            fsim = fsim.masked_fill(~fmask, mask_value)
-
-            fattn = fsim.softmax(dim = -1)
-
-            fine_attn_out = einsum(fattn, fv, 'b h i j, b h j d -> b h i d')
+            fine_attn_out = attend(fq, fk, fv, mask = fmask)

         # 3. overlapping sliding window, this is unsurprising and expected - `s` for sliding
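
Shape-wise, the fine branch now keeps the kv heads un-repeated and groups the query heads instead. A shape-only sketch (arbitrary sizes, not repo code) of the new einsum pattern, where every query row i carries its own j selected fine keys and values:

import torch
from einops import einsum, rearrange

b, h, qh, i, j, d = 1, 2, 4, 16, 32, 64      # 2 kv heads, 4 grouped query heads each

fq = torch.randn(b, h * qh, i, d)
fk = torch.randn(b, h, i, j, d)              # per-query-row selected keys
fv = torch.randn(b, h, i, j, d)

fq = rearrange(fq, 'b (h qh) ... -> b h qh ...', qh = qh)
fsim = einsum(fq, fk, 'b h qh i d, b h i j d -> b h qh i j')
fattn = fsim.softmax(dim = -1)

fine_out = einsum(fattn, fv, 'b h qh i j, b h i j d -> b h qh i d')
fine_out = rearrange(fine_out, 'b h qh ... -> b (h qh) ...')
print(fine_out.shape)                        # torch.Size([1, 8, 16, 64])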

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.36"
+version = "0.0.37"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

test_flex_masks.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@

 # compress

-print('compress mask:', create_compress_mask(1024, 256, 4))
+print('compress mask:', create_compress_mask(512, 128, 4, mem_kv_len = 16))

 # fine
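
The module's own comment notes this compress mask is a flex attention demonstration rather than something the main path can use (importance scores need the raw attention logits), but for completeness, a hedged sketch of consuming the returned BlockMask with PyTorch's flex_attention (assumes a PyTorch build that ships flex attention; shapes are arbitrary):

import torch
from torch.nn.attention.flex_attention import flex_attention
from native_sparse_attention_pytorch.native_sparse_attention import create_compress_mask

block_mask = create_compress_mask(512, 128, 4, mem_kv_len = 16)

q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 128 + 16, 64)   # 128 compressed kv plus 16 learned memory kv
v = torch.randn(1, 8, 128 + 16, 64)

out = flex_attention(q, k, v, block_mask = block_mask)
print(out.shape)                      # torch.Size([1, 8, 512, 64])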
