
Commit 822dfdc

committed
last commit for the day for this project
1 parent 4a600d2 commit 822dfdc

File tree

1 file changed: +4, -4 lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 4 additions & 4 deletions
@@ -266,19 +266,19 @@ def forward(

         fattn = fsim.softmax(dim = -1)

-        fine_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
+        fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')

-        fine_out = fine_out[..., :seq_len, :]
+        fine_attn_out = fine_attn_out[..., :seq_len, :]

         # 3. overlapping sliding window, this is unsurprising and expected

-        local_attn_out = self.sliding_window(q, k, v)
+        sliding_window_attn_out = self.sliding_window(q, k, v)

         # combine strategies

         strategy_weighted_combine = self.to_strategy_combine(inp)

-        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, fine_out, local_attn_out]), 'b h n s, s b h n d -> b h n d')
+        out = einsum(strategy_weighted_combine, stack([compressed_attn_out, fine_attn_out, sliding_window_attn_out]), 'b h n s, s b h n d -> b h n d')

         # merge heads and combine them

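For orientation, the renamed tensors are the outputs of the three attention branches (compressed, fine, and sliding window), and the final einsum mixes them with learned per-token, per-head strategy weights. Below is a minimal standalone sketch of that combine step; the toy shapes and the softmax gating are assumptions for illustration, not the module's actual code.

import torch
from einops import einsum

# assumed toy shapes: batch, heads, tokens, head dim
b, h, n, d = 2, 4, 16, 32

# stand-ins for the three branch outputs renamed in this commit
compressed_attn_out     = torch.randn(b, h, n, d)
fine_attn_out           = torch.randn(b, h, n, d)
sliding_window_attn_out = torch.randn(b, h, n, d)

# per-token, per-head weights over the 3 strategies; in the module these come
# from self.to_strategy_combine(inp) (assumed here to be softmax-normalized)
strategy_weighted_combine = torch.randn(b, h, n, 3).softmax(dim = -1)

# stack the branch outputs along a new leading 'strategy' axis: (s, b, h, n, d)
stacked = torch.stack([compressed_attn_out, fine_attn_out, sliding_window_attn_out])

# weighted sum over the strategy axis, the same contraction as in the diff
out = einsum(strategy_weighted_combine, stacked, 'b h n s, s b h n d -> b h n d')

print(out.shape)  # torch.Size([2, 4, 16, 32])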
0 commit comments
