Commit 0fad316

initial forward needs to return cache with rotated keys
1 parent b77f9d0 commit 0fad316

File tree

2 files changed: 8 additions, 7 deletions

    native_sparse_attention_pytorch/native_sparse_attention.py
    pyproject.toml

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 7 additions & 6 deletions

@@ -532,11 +532,6 @@ def forward(
 
         q, k, v = map(self.split_heads, (q, k, v))
 
-        # handle cache
-
-        if return_cache:
-            cache_kv = (k, v)
-
         # compressed key / values - variables prepended with `c` stands for compressed
 
         k_pos = repeat(self.k_intrablock_positions, 'h n d -> h (r n) d', r = num_compress_blocks)
@@ -573,7 +568,13 @@
         compressed_attn_out, csim = attend(cq, ck, cv, mask = cmask, return_sim = True)
 
         # for 2. and 3., will give them relative positions with rotary - compressed needs to be handled separately (even if they already have intra block absolute positions)
-        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # handle cache
+
+        if return_cache:
+            cache_kv = (k, v)
 
         # 2. fine attention over selected based on compressed attention logits - variables prepended with `f` stands for the fine attention pathway
 
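
The substance of the change: the kv cache is now captured after `rotate_queries_with_cached_keys`, so the keys handed back when `return_cache = True` already carry their rotary positions, and a later decoding step only has to rotate the tokens it adds. Below is a minimal sketch of that bookkeeping, not the library's API: `toy_rotate` is a stand-in rather than real RoPE, `toy_step` is a hypothetical prefill/decode step rather than the actual `forward`, and causal masking is omitted.

import torch

def toy_rotate(t, offset = 0):
    # stand-in for a rotary embedding: add a position-dependent shift so that
    # "rotated at position p" is distinguishable from "unrotated" (not real RoPE)
    seq_len = t.shape[-2]
    positions = torch.arange(offset, offset + seq_len, dtype = t.dtype).unsqueeze(-1)
    return t + positions

def toy_step(q, k, v, cache = None):
    # hypothetical prefill / decode step returning a cache of *rotated* keys,
    # mirroring the moved `cache_kv = (k, v)` in this commit
    if cache is not None:
        cached_k, cached_v = cache
        offset = cached_k.shape[-2]
    else:
        cached_k = cached_v = None
        offset = 0

    # rotate only the newly arrived queries / keys, at their absolute positions
    q = toy_rotate(q, offset)
    k = toy_rotate(k, offset)

    if cached_k is not None:
        k = torch.cat((cached_k, k), dim = -2)
        v = torch.cat((cached_v, v), dim = -2)

    # the cache is captured *after* rotation, so the next step can append its
    # own rotated keys without touching the old ones again
    cache = (k, v)

    attn = (q @ k.transpose(-1, -2)).softmax(dim = -1)  # masking omitted for brevity
    return attn @ v, cache

# prefill on 4 tokens, then decode 1 more token reusing the cache
q = k = v = torch.randn(1, 4, 8)
out, cache = toy_step(q, k, v)

q1 = k1 = v1 = torch.randn(1, 1, 8)
out1, cache = toy_step(q1, k1, v1, cache)

The design point is that re-rotating previously cached keys on the next call would assign them new absolute positions; storing them post-rotation keeps their positions stable across decoding steps.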

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.71"
+version = "0.0.72"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
