Skip to content

Commit 41dbb54

Browse files
committed
redo get_at with gather, but keep around the ein notation for readability
1 parent 82a28be commit 41dbb54

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,15 @@ def forward(
265265
fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
266266
fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
267267

268-
fk = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fk, selected_block_indices)
269-
fv = einx.get_at('b h [w] j d, b h i selected -> b h i selected j d', fv, selected_block_indices)
268+
# get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices)
269+
270+
fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
271+
fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
272+
273+
selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1])
274+
275+
fk = fk.gather(3, selected_block_indices)
276+
fv = fv.gather(3, selected_block_indices)
270277

271278
# handle maybe gating
272279

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "native-sparse-attention-pytorch"
3-
version = "0.0.6"
3+
version = "0.0.7"
44
description = "Native Sparse Attention"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)