
Commit 37ad19b

It appears from the diagrams that they also use block-causal attention across the diagonal for the fine attention, though this is not expanded upon in the paper. Window the queries to prepare for this, in case it turns out to be important.
1 parent 78ec6b0 commit 37ad19b

File tree

  • native_sparse_attention_pytorch

1 file changed (+9, -1 lines)
native_sparse_attention_pytorch/nsa.py

Lines changed: 9 additions & 1 deletion
@@ -214,13 +214,19 @@ def forward(
 
         fmask = selected_importance_values > mask_value
 
+        fq = q
         fk = k
         fv = v
 
         if seq_len < fine_divisible_seq_len:
             remainder = fine_divisible_seq_len - seq_len
             fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
             fv = pad_at_dim(fv, (0, remainder), value = 0., dim = -2)
+            fq = pad_at_dim(fq, (0, remainder), value = 0., dim = -2)
+
+            fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)
+
+            selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)
 
         fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
         fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
@@ -229,14 +235,16 @@ def forward(
         fk = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fk, selected_block_indices)
         fv = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fv, selected_block_indices)
 
-        fsim = einsum(q, fk, 'b h i d, b h i j d -> b h i j') * self.scale
+        fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
 
         fsim = fsim.masked_fill(fmask, mask_value)
 
         fattn = fsim.softmax(dim = -1)
 
         fine_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
 
+        fine_out = fine_out[..., :seq_len, :]
+
         # 3. overlapping sliding window, this is unsurprising and expected
 
         local_attn_out = self.sliding_window(q, k, v)
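The pattern in this change: pad the queries (along with the fine mask and the selected block indices) up to the block-divisible sequence length, run fine attention over the padded length, then slice the output back to the original seq_len. Below is a minimal runnable sketch of that pad-then-slice pattern, using a simplified stand-in for the repo's pad_at_dim helper; the toy shapes and the placeholder attention step are assumptions for illustration, not the repo's actual code.

import torch
import torch.nn.functional as F

# simplified stand-in for the repo's pad_at_dim helper:
# pads tensor `t` along `dim` by (left, right) with a constant value
def pad_at_dim(t, pad, value = 0., dim = -2):
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

# toy shapes: batch, heads, sequence length, head dim, fine block size
b, h, seq_len, d, block_size = 1, 2, 10, 8, 4
q = torch.randn(b, h, seq_len, d)

# round the sequence length up to a multiple of the fine block size
fine_divisible_seq_len = ((seq_len + block_size - 1) // block_size) * block_size
remainder = fine_divisible_seq_len - seq_len  # 2 in this toy example

# pad the queries on the right, as the commit now does for fq alongside fk / fv
fq = pad_at_dim(q, (0, remainder), value = 0., dim = -2)
assert fq.shape[-2] == fine_divisible_seq_len

# ... fine attention over the padded length would run here ...
fine_out = fq  # placeholder standing in for the fine attention output

# slice back to the original length, mirroring fine_out[..., :seq_len, :]
fine_out = fine_out[..., :seq_len, :]
assert fine_out.shape[-2] == seq_len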

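As for the commit message's observation: block-causal attention across the diagonal would mean that, within the fine block containing the query, each query attends only to positions at or before itself. A toy sketch of what such a within-block causal mask could look like; the block size n = 4 and the mask convention are assumptions for illustration, not taken from the paper or the repo.

import torch

n = 4  # assumed fine block size, for illustration only

# within a diagonal block, query offset i may attend to key offset j only when j <= i
i = torch.arange(n)[:, None]  # query offsets within the block
j = torch.arange(n)[None, :]  # key offsets within the block
block_causal_mask = j <= i    # True where attention is allowed

print(block_causal_mask)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])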