Skip to content

Commit a834a5a

Browse files
committed
fix the case of no selected kv during inference
1 parent aa359ea commit a834a5a

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ def forward_inference(
454454

455455
# select out the sparse kv segments as defined by compressed attention map as importance score
456456

457+
fmask = None
458+
457459
if has_selected_kv_for_fine_attn:
458460
if self.query_heads_share_selected_kv:
459461
importance_scores = reduce(importance_scores, 'b (h grouped_queries) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)
@@ -493,7 +495,8 @@ def forward_inference(
493495

494496
fsim = einsum(fq, fk, 'b h gh i d, b h j d -> b h gh i j') * scale
495497

496-
fsim = einx.where('b h i j, b h gh i j, -> b h gh i j', fmask, fsim, max_neg_value(fsim))
498+
if exists(fmask):
499+
fsim = einx.where('b h i j, b h gh i j, -> b h gh i j', fmask, fsim, max_neg_value(fsim))
497500

498501
fattn = fsim.softmax(dim = -1)
499502

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "native-sparse-attention-pytorch"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
description = "Native Sparse Attention"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments (0)