Commit a6101f5

fix

1 parent 0fad316 commit a6101f5
File tree

2 files changed: +11 -10 lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 10 additions & 9 deletions
@@ -361,7 +361,7 @@ def forward_inference(

         # rotate after updating the compression running k/v

-        q = self.rotary_emb.rotate_queries_or_keys(q, offset = cache_len)
+        rotated_q = self.rotary_emb.rotate_queries_or_keys(q, offset = cache_len)
         k = self.rotary_emb.rotate_queries_or_keys(k, offset = cache_len)

         # handle cache, which stores the rotated
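For reference, the rotation step above can be reproduced in isolation. A minimal sketch, assuming the rotary-embedding-torch package this repo builds on; the tensor shapes and rotary dimension below are illustrative, and only rotate_queries_or_keys and its offset argument come from the diff:

import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim = 32)      # rotary dim is illustrative

cache_len = 10                              # tokens already held in the k/v cache
q = torch.randn(1, 8, 1, 64)                # single decoded query: (batch, heads, seq = 1, head dim)
k = torch.randn(1, 8, 1, 64)                # matching key for the new token

# rotate at the position just past the cache; the rotated query gets its own name
# so the unrotated q stays available for the later branches
rotated_q = rotary_emb.rotate_queries_or_keys(q, offset = cache_len)
k = rotary_emb.rotate_queries_or_keys(k, offset = cache_len)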
@@ -459,7 +459,7 @@ def forward_inference(

         # remove later

-        fq = rearrange(q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)
+        fq = rearrange(rotated_q, 'b (h gh) ... -> b h gh ...', gh = self.num_grouped_queries)

         fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale
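The grouped-query pattern in the hunk above can be shape-checked standalone. A minimal sketch with made-up sizes; the head counts, grouping factor, and sequence length are all illustrative:

import torch
from einops import rearrange, einsum

b, h, gh, j, d = 1, 4, 2, 16, 64            # kv heads h, queries per kv head gh
scale = d ** -0.5

rotated_q = torch.randn(b, h * gh, 1, d)    # grouped queries for one decoding step
fk = torch.randn(b, h, j, d)                # fine-selected keys per kv head

# fold the grouped-query factor out of the head dimension, as the diff now does with rotated_q
fq = rearrange(rotated_q, 'b (h gh) ... -> b h gh ...', gh = gh)

# each grouped query attends over its kv head's selected keys
fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale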

@@ -476,11 +476,12 @@
         v = repeat(v, 'b h ... -> b (h gh) ...', gh = self.num_grouped_queries)

         sliding_slice = (Ellipsis, slice(-(sliding_window + 1), None), slice(None))
-        rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k[sliding_slice])

-        sim = einsum(rotated_q, rotated_k, 'b h i d, b h j d -> b h i j') * scale
+        k, v = k[sliding_slice], v[sliding_slice]
+
+        sim = einsum(rotated_q, k, 'b h i d, b h j d -> b h i j') * scale
         attn = sim.softmax(dim = -1)
-        sliding_window_attn_out = einsum(attn, v[sliding_slice], 'b h i j, b h j d -> b h i d')
+        sliding_window_attn_out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

         # combine strategies
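The reworked sliding-window branch can be exercised on its own. A minimal sketch, assuming a cache of already-rotated keys/values and a pre-rotated single-step query; the shapes and window size are illustrative:

import torch
from einops import einsum

b, h, d = 1, 8, 64
sliding_window = 4
cache_len = 16
scale = d ** -0.5

rotated_q = torch.randn(b, h, 1, d)               # current query, rotated earlier in the step
k = torch.randn(b, h, cache_len + 1, d)           # rotated keys including the new token
v = torch.randn(b, h, cache_len + 1, d)

# keep only the last (sliding_window + 1) positions, then attend, as in the new code path
sliding_slice = (Ellipsis, slice(-(sliding_window + 1), None), slice(None))
k, v = k[sliding_slice], v[sliding_slice]

sim = einsum(rotated_q, k, 'b h i d, b h j d -> b h i j') * scale
attn = sim.softmax(dim = -1)
sliding_window_attn_out = einsum(attn, v, 'b h i j, b h j d -> b h i d')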

@@ -630,8 +631,8 @@ def forward(

         # handle if number of total blocks is less than number to select for fine attention

-        fq = rotated_q
-        fk = rotated_k
+        fq = q
+        fk = k
         fv = v

         if has_selected_kv_for_fine_attn:
@@ -757,8 +758,8 @@ def forward(

         # 3. overlapping sliding window, this is unsurprising and expected - `s` for sliding

-        sq = rotated_q
-        sk = rotated_k
+        sq = q
+        sk = k
         sv = v

         if exists(sliding_window_flex_mask):
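For the parallel forward path above, an overlapping sliding window can also be expressed as a banded causal mask when no flex-attention mask is supplied. This is only an illustration of the masking logic under assumed shapes, not the repository's implementation:

import torch

n, sliding_window = 8, 3

i = torch.arange(n)[:, None]
j = torch.arange(n)[None, :]

# causal and within the window: token i may attend to j when 0 <= i - j <= sliding_window
sliding_mask = (j <= i) & (i - j <= sliding_window)

sim = torch.randn(1, 1, n, n)                     # (batch, heads, i, j) similarity logits
sim = sim.masked_fill(~sliding_mask, float('-inf'))
attn = sim.softmax(dim = -1)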

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.72"
+version = "0.0.73"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
