Skip to content

Commit 583cc11

Browse files
committed
refactor compressed pathway with gqa
1 parent 57b71de commit 583cc11

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
# b - batch
2222
# h - heads
23+
# qh - grouped query heads
2324
# n - sequence (token level or compressed)
2425
# w - windows, for fine or compressed
2526
# i, j - query / key sequence
@@ -295,6 +296,7 @@ def forward(
295296
k_compress_input = self.split_compress_window(k[..., :compress_divisible_seq_len, :] + k_pos)
296297
v_compress_input = self.split_compress_window(v[..., :compress_divisible_seq_len, :] + v_pos)
297298

299+
cq = q
298300
ck = self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
299301
cv = self.v_compress(v_compress_input)
300302

@@ -307,9 +309,9 @@ def forward(
307309
ck = cat((mem_ck, ck), dim = -2)
308310
cv = cat((mem_cv, cv), dim = -2)
309311

310-
ck, cv = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (ck, cv))
312+
cq = rearrange(cq, 'b (h qh) ... -> b h qh ...', qh = self.num_grouped_queries)
311313

312-
csim = einsum(q, ck, 'b h i d, b h j d -> b h i j') * self.scale
314+
csim = einsum(cq, ck, 'b h qh i d, b h j d -> b h qh i j') * self.scale
313315

314316
cq_seq = arange(seq_len, device = device)
315317

@@ -324,7 +326,9 @@ def forward(
324326

325327
cattn = csim.softmax(dim = -1)
326328

327-
compressed_attn_out = einsum(cattn, cv, 'b h i j, b h j d -> b h i d')
329+
compressed_attn_out = einsum(cattn, cv, 'b h qh i j, b h j d -> b h qh i d')
330+
331+
compressed_attn_out, cattn = tuple(rearrange(t, 'b h qh ... -> b (h qh) ...') for t in (compressed_attn_out, cattn))
328332

329333
# for 2. and 3., will give them relative positions with rotary - compressed needs to be handled separately (even if they already have intra block absolute positions)
330334

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "native-sparse-attention-pytorch"
3-
version = "0.0.35"
3+
version = "0.0.36"
44
description = "Native Sparse Attention"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments (0)