Commit 1de7c94

coordinate descent was unstable, just use a one hot straight through instead
1 parent 839d906 commit 1de7c94

3 files changed, +10 -11 lines changed


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 8 additions & 9 deletions
@@ -7,8 +7,6 @@
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList

-from colt5_attention import topk as differentiable_topk
-
 from local_attention import LocalAttention

 from rotary_embedding_torch import RotaryEmbedding
@@ -87,7 +85,6 @@ def __init__(
         num_compressed_mem_kv = 4,
         norm = True,
         use_diff_topk = False,
-        diff_topk_coor_descent_iters = 10.
     ):
         super().__init__()
         self.heads = heads
@@ -142,7 +139,6 @@ def __init__(
142139
# selection related
143140

144141
self.use_diff_topk = use_diff_topk
145-
self.diff_topk_coor_descent_iters = diff_topk_coor_descent_iters
146142

147143
self.selection_block_size = selection_block_size
148144
self.num_selected_blocks = num_selected_blocks
@@ -222,12 +218,12 @@ def forward(

         # 2. fine attention over selected based on compressed attention logits

-        importance_scores = csim[..., num_mem_compress_kv:]
+        importance_scores = cattn[..., num_mem_compress_kv:]
+
+        selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)

         if self.use_diff_topk:
-            selected_importance_values, selected_block_indices, _, gates = differentiable_topk(importance_scores, self.num_selected_blocks, fused = True)
-        else:
-            selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)
+            gates = selected_importance_values + (1. - selected_importance_values).detach()

         fmask = selected_importance_values > mask_value

@@ -247,6 +243,9 @@ def forward(
247243

248244
selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)
249245

246+
if self.use_diff_topk:
247+
gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)
248+
250249
# handle block causal diagonal in the diagram, but run experiments without to see
251250

252251
fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
@@ -272,7 +271,7 @@ def forward(
         # handle maybe gating

         if self.use_diff_topk:
-            gates = F.pad(gates, (0, 1, 0, remainder), value = 1.)
+            gates = F.pad(gates, (0, 1), value = 1.)

             fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
             fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
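
For intuition, here is a minimal, self-contained sketch (not code from this repo) of the one-hot straight-through gate that replaces the coordinate-descent `differentiable_topk`: the gate's forward value is exactly 1, while the backward pass routes gradients through the selected compressed-attention scores. The tensor shapes and variable setup below are illustrative assumptions; only the gate formula itself comes from the diff above.

```python
import torch

# toy importance scores: (batch, heads, query positions, key blocks) - illustrative shape
importance_scores = torch.randn(1, 2, 8, 16).softmax(dim = -1).requires_grad_()

num_selected_blocks = 4

# hard top-k selection, as in the non-differentiable path
selected_importance_values, selected_block_indices = importance_scores.topk(num_selected_blocks, dim = -1)

# straight-through gate: forward value is exactly 1., but d(gates)/d(selected_importance_values) is 1,
# so the compressed attention scores still receive a learning signal through whatever the gates multiply
gates = selected_importance_values + (1. - selected_importance_values).detach()

assert torch.allclose(gates, torch.ones_like(gates))  # hard one-hot behavior in the forward pass

gates.sum().backward()
assert importance_scores.grad is not None             # gradient flows back to the scores
```
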

pyproject.toml

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.4"
+version = "0.0.5"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -23,7 +23,6 @@ classifiers=[
 ]

 dependencies = [
-    "CoLT5-attention>=0.11.1",
     "einx>=0.3.0",
     "einops>=0.8.0",
     "local-attention>=1.11.1",

train.py

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ def base_decoding(
         compress_block_size = 32,
         selection_block_size = 32,
         num_selected_blocks = 2,
+        use_diff_topk = False
     )
 ).cuda()
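
Below is a hedged sketch of what toggling the new flag might look like when constructing the attention module directly. Only `compress_block_size`, `selection_block_size`, `num_selected_blocks`, and `use_diff_topk` appear in this commit; the class name `SparseAttention`, the import path, and the remaining constructor arguments are assumptions about the surrounding code, not confirmed by the diff.

```python
import torch

# assumed import path and class name; kwargs marked "assumed" are guesses, not from this commit
from native_sparse_attention_pytorch.native_sparse_attention import SparseAttention

attn = SparseAttention(
    dim = 512,                    # assumed
    dim_head = 64,                # assumed
    heads = 8,                    # assumed
    sliding_window_size = 32,     # assumed
    compress_block_size = 32,     # mirrored from train.py in this commit
    selection_block_size = 32,
    num_selected_blocks = 2,
    use_diff_topk = True          # enables the straight-through gating added here
)

x = torch.randn(1, 1024, 512)
out = attn(x)                     # assumed call signature: (batch, seq, dim) -> same shape
```
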
