Commit 5eff064

ok, it lines up with fig2 now
1 parent 37ad19b commit 5eff064

2 files changed: +20 -2 lines changed

native_sparse_attention_pytorch/nsa.py

Lines changed: 19 additions & 1 deletion
@@ -218,6 +218,7 @@ def forward(
         fk = k
         fv = v
 
+
         if seq_len < fine_divisible_seq_len:
             remainder = fine_divisible_seq_len - seq_len
             fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
@@ -228,13 +229,30 @@ def forward(
 
             selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)
 
+        # handle block causal diagonal in the diagram, but run experiments without to see
+
+        fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
+        fine_window_seq = rearrange(fine_window_seq, 'n -> n 1').expand_as(selected_block_indices)
+        selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
+
+        fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
+
+        causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
+        causal_mask = repeat(causal_mask, 'i j -> (w i) 1 j', w = num_fine_blocks).expand_as(fmask)
+
+        fmask = cat((fmask, causal_mask), dim = -2)
+        fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
+
+        # select out the spatial crops of keys / values for fine attention
+
         fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
         fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
-        fmask = repeat(fmask, 'b h i w -> b h i (w j)', j = self.selection_block_size)
 
         fk = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fk, selected_block_indices)
         fv = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fv, selected_block_indices)
 
+        # fine attention
+
         fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
 
         fsim = fsim.masked_fill(fmask, mask_value)
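
As a note on what the added lines are doing, here is a small standalone sketch of the idea, not the library's exact code: each query position's own block index is appended to its selected coarse blocks (the block causal diagonal of fig2), and the fine-attention mask is widened with a per-block causal pattern for that extra block. The toy sizes, the tensor names, and the True-means-attend convention below are assumptions for illustration only.

import torch
from einops import rearrange, repeat

batch, heads = 1, 2
selection_block_size = 4
num_fine_blocks = 3
seq_len = selection_block_size * num_fine_blocks   # assume seq_len is already block-divisible
num_selected = 2                                    # selected coarse blocks per query

# pretend these came out of the block-selection step
selected_block_indices = torch.randint(0, num_fine_blocks, (batch, heads, seq_len, num_selected))
fmask = torch.ones(batch, heads, seq_len, num_selected, dtype = torch.bool)  # True = attend (assumed convention)

# block causal diagonal: every query also attends to the block it lives in
own_block = torch.arange(seq_len) // selection_block_size
own_block = repeat(own_block, 'n -> b h n 1', b = batch, h = heads)
selected_block_indices = torch.cat((selected_block_indices, own_block), dim = -1)

# expand the per-block mask to per-token resolution within each selected block
fmask = repeat(fmask, 'b h i w -> b h i w j', j = selection_block_size)

# inside the diagonal block, token i of the block only sees tokens <= i
causal = torch.ones(selection_block_size, selection_block_size, dtype = torch.bool).tril()
causal = repeat(causal, 'i j -> b h (w i) 1 j', b = batch, h = heads, w = num_fine_blocks)

fmask = torch.cat((fmask, causal), dim = -2)          # mask column for the appended diagonal block
fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')  # flatten to one boolean per gathered key

print(selected_block_indices.shape)  # torch.Size([1, 2, 12, 3]) -> 2 selected blocks + own block
print(fmask.shape)                   # torch.Size([1, 2, 12, 12]) -> 3 blocks * 4 tokens per query

In the commit itself the gathered keys/values are then fetched per query with einx.get_at using these indices before the fine attention scores are computed.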

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.1"
+version = "0.0.2"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments
