Skip to content

Commit 832050a

Browse files
authored
Merge pull request #25 from Mr-Grin/main
compress_block_sliding_stride
2 parents 69a7691 + 0f32e40 commit 832050a

File tree

5 files changed

+58
-73
lines changed

5 files changed

+58
-73
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ attn = SparseAttention(
3434
heads = 8,
3535
sliding_window_size = 2,
3636
compress_block_size = 4,
37+
compress_block_sliding_stride = 2,
3738
selection_block_size = 4,
3839
num_selected_blocks = 2
3940
)

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 33 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def sliding_mask(_, __, q_idx, kv_idx):
5858
block_mask = create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
5959
return block_mask
6060

61-
def create_compress_mask(seq_len, kv_seq_len, compress_block_size, mem_kv_len = 0, causal = True):
61+
def create_compress_mask(seq_len, kv_seq_len, compress_block_sliding_stride, mem_kv_len = 0, causal = True):
6262

6363
if not causal:
6464
return None
@@ -70,7 +70,7 @@ def compress_mask(_, __, q_idx, kv_idx):
7070
is_mem_kv = kv_idx < mem_kv_len
7171

7272
kv_without_mem = kv_idx - mem_kv_len
73-
compress_kv_idx = (kv_without_mem * compress_block_size) + (compress_block_size - 1)
73+
compress_kv_idx = (kv_without_mem * compress_block_sliding_stride) + (compress_block_sliding_stride - 1)
7474

7575
causal_mask = q_idx > compress_kv_idx
7676
return causal_mask | is_mem_kv
@@ -193,9 +193,9 @@ def __init__(
193193
heads,
194194
sliding_window_size,
195195
compress_block_size,
196+
compress_block_sliding_stride,
196197
selection_block_size,
197198
num_selected_blocks,
198-
compress_block_overlap_len = 0, # the amount of overlap of a given compression block to the previous block
199199
kv_heads = None,
200200
num_compressed_mem_kv = 1,
201201
causal = False,
@@ -261,40 +261,28 @@ def __init__(
261261
# compress strategy
262262

263263
self.compress_block_size = compress_block_size
264+
self.compress_block_sliding_stride = compress_block_sliding_stride
265+
assert self.compress_block_size >= self.compress_block_sliding_stride, 'compress_block_size must be >= compress_block_sliding_stride'
266+
assert self.compress_block_sliding_stride > 0, 'compress_block_sliding_stride must be greater than 0'
267+
assert divisible_by(selection_block_size, self.compress_block_sliding_stride), f'selection_block_size {selection_block_size} must be divisible by compress_block_sliding_stride {self.compress_block_sliding_stride}'
268+
269+
# Compression window splitting
270+
self.split_compress_window = nn.Sequential(
271+
Rearrange('b h n d -> (b h) d 1 n'),
272+
nn.ZeroPad2d(((compress_block_size - compress_block_sliding_stride), 0, 0, 0)),
273+
nn.Unfold(kernel_size=(1, self.compress_block_size), stride=(1, self.compress_block_sliding_stride)),
274+
Rearrange('(b h) (d n) w -> b h w n d', d=dim_head, h=kv_heads, n=self.compress_block_size)
275+
)
264276

265277
assert num_compressed_mem_kv > 0
266-
267-
# the function for splitting out the compression windows for the mlp
268-
269-
compress_block_has_overlap = compress_block_overlap_len > 0
270-
compress_window_size = compress_block_size + compress_block_overlap_len
271-
272-
if not compress_block_has_overlap:
273-
split_compress_window_fn = Rearrange('b h (w n) d -> b h w n d', n = compress_block_size)
274-
else:
275-
split_compress_window_fn = nn.Sequential(
276-
Rearrange('b h n d -> (b h) d 1 n'),
277-
nn.ZeroPad2d((compress_block_overlap_len, 0, 0, 0)),
278-
nn.Unfold(kernel_size = (1, compress_window_size), stride = (1, compress_block_size)),
279-
Rearrange('(b h) (d n) w -> b h w n d', d = dim_head, h = kv_heads)
280-
)
281-
282-
self.split_compress_window = split_compress_window_fn
283-
self.compress_window_size = compress_window_size
284-
285-
assert compress_block_overlap_len <= compress_block_size
286-
self.compress_block_overlap_len = compress_block_overlap_len
287-
288-
# compression attention related parameters
289-
290278
self.num_mem_compress_kv = num_compressed_mem_kv
291279
self.compress_mem_kv = nn.Parameter(torch.zeros(2, kv_heads, num_compressed_mem_kv, dim_head))
292280

293-
self.k_intrablock_positions = nn.Parameter(torch.zeros(kv_heads, compress_window_size, dim_head))
294-
self.v_intrablock_positions = nn.Parameter(torch.zeros(kv_heads, compress_window_size, dim_head))
281+
self.k_intrablock_positions = nn.Parameter(torch.zeros(kv_heads, self.compress_block_size, dim_head))
282+
self.v_intrablock_positions = nn.Parameter(torch.zeros(kv_heads, self.compress_block_size, dim_head))
295283

296284
if not exists(compress_mlp):
297-
compress_dim = compress_window_size * dim_head
285+
compress_dim = self.compress_block_size * dim_head
298286
compress_mlp_dim_hidden = int(compress_mlp_expand_factor * compress_dim)
299287

300288
compress_mlp = nn.Sequential(
@@ -310,11 +298,7 @@ def __init__(
310298
# selection related
311299

312300
self.use_diff_topk = use_diff_topk
313-
314301
self.query_heads_share_selected_kv = query_heads_share_selected_kv
315-
316-
assert divisible_by(selection_block_size, compress_block_size), f'selection block size {selection_block_size} must be greater than or equal to compress block size {compress_block_size}, as well as divisible by the compress block size'
317-
318302
self.selection_block_size = selection_block_size
319303

320304
assert num_selected_blocks >= 0
@@ -376,8 +360,6 @@ def forward_inference(
376360
seq_len = cache_len + 1
377361

378362
sliding_window = self.sliding_window_size
379-
compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
380-
compress_overlap_len = self.compress_block_overlap_len
381363

382364
fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
383365
num_fine_blocks = fine_divisible_seq_len // self.selection_block_size
@@ -435,7 +417,7 @@ def forward_inference(
435417

436418
running_compress_seq_len = run_k.shape[-2]
437419

438-
if divisible_by(running_compress_seq_len, self.compress_block_size + compress_overlap_len):
420+
if divisible_by(running_compress_seq_len, self.compress_block_size):
439421
k_compress_input = rearrange(run_k, 'b h n d -> b h 1 n d')
440422
v_compress_input = rearrange(run_v, 'b h n d -> b h 1 n d')
441423

@@ -445,6 +427,7 @@ def forward_inference(
445427
next_ck = self.k_compress(k_compress_input)
446428
next_cv = self.v_compress(v_compress_input)
447429

430+
compress_overlap_len = self.compress_block_size - self.compress_block_sliding_stride
448431
run_kv_slice = slice(-compress_overlap_len, None) if compress_overlap_len > 0 else slice(0, 0)
449432

450433
run_k = run_k[..., run_kv_slice, :]
@@ -461,9 +444,9 @@ def forward_inference(
461444
importance_scores = csim[..., self.num_mem_compress_kv:]
462445

463446
num_compress_blocks = importance_scores.shape[-1]
464-
num_compress_per_fine = self.selection_block_size // self.compress_block_size
447+
num_compress_per_fine = self.selection_block_size // self.compress_block_sliding_stride
465448

466-
if self.compress_block_size != self.selection_block_size:
449+
if self.compress_block_sliding_stride != self.selection_block_size:
467450
compress_seq_len = round_down_mult(num_compress_blocks, num_compress_per_fine)
468451
importance_scores = importance_scores[..., :compress_seq_len]
469452
importance_scores = reduce(importance_scores, '... (j num_compress_per_fine) -> ... j', 'mean', num_compress_per_fine = num_compress_per_fine)
@@ -582,10 +565,10 @@ def forward(
582565

583566
batch, seq_len, scale, heads, kv_heads, device = *inp.shape[:2], self.scale, self.heads, self.kv_heads, inp.device
584567

585-
compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_size)
586-
num_compress_blocks = compress_divisible_seq_len // self.compress_block_size
568+
compress_divisible_seq_len = round_down_mult(seq_len, self.compress_block_sliding_stride)
569+
num_compress_blocks = compress_divisible_seq_len // self.compress_block_sliding_stride
587570

588-
compress_overlap_len = self.compress_block_overlap_len
571+
compress_overlap_len = self.compress_block_size - self.compress_block_sliding_stride
589572
has_compress_overlap = compress_overlap_len > 0
590573

591574
fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)
@@ -609,7 +592,7 @@ def forward(
609592
k_compress_input = self.split_compress_window(k_compress_input)
610593
v_compress_input = self.split_compress_window(v_compress_input)
611594
else:
612-
k_compress_input, v_compress_input = tuple(t.reshape(batch, kv_heads, 0, self.compress_window_size, self.dim_head) for t in (k_compress_input, v_compress_input))
595+
k_compress_input, v_compress_input = tuple(t.reshape(batch, kv_heads, 0, self.compress_block_size, self.dim_head) for t in (k_compress_input, v_compress_input))
613596

614597
# add the intra block positions
615598

@@ -645,10 +628,10 @@ def forward(
645628
# compressed masking
646629

647630
cmask = None
648-
631+
# TODO
649632
if self.causal:
650633
cq_seq = arange(seq_len, device = device)
651-
ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_size) - 1
634+
ck_seq = ((arange(num_compress_blocks, device = device) + 1) * self.compress_block_sliding_stride) - 1
652635
ck_seq = F.pad(ck_seq, (num_mem_compress_kv, 0), value = -1)
653636

654637
cmask = einx.less('j, i -> i j', ck_seq, cq_seq)
@@ -686,9 +669,9 @@ def forward(
686669

687670
if has_selected_kv_for_fine_attn:
688671

689-
if self.compress_block_size != self.selection_block_size:
672+
if self.compress_block_sliding_stride != self.selection_block_size:
690673

691-
num_compress_per_fine = self.selection_block_size // self.compress_block_size
674+
num_compress_per_fine = self.selection_block_size // self.compress_block_sliding_stride
692675

693676
round_down_score_len = round_down_mult(importance_scores.shape[-1], num_compress_per_fine)
694677
importance_scores = importance_scores[..., :round_down_score_len]
@@ -729,10 +712,7 @@ def forward(
729712

730713
selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)
731714

732-
gates = None
733-
734-
if self.use_diff_topk:
735-
gates = straight_through(selected_importance_values, 1.)
715+
gates = straight_through(selected_importance_values, 1.) if self.use_diff_topk else None
736716

737717
if self.use_triton_kernel and not disable_triton_kernel:
738718

@@ -762,9 +742,7 @@ def forward(
762742
fmask = selected_importance_values > 1e-10
763743

764744
if seq_len < fine_divisible_seq_len:
765-
fk = pad_to_multiple(fk)
766-
fv = pad_to_multiple(fv)
767-
fq = pad_to_multiple(fq)
745+
fk, fv, fq = map(pad_to_multiple, (fk, fv, fq))
768746

769747
fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)
770748

@@ -846,9 +824,7 @@ def forward(
846824
seq_len = fk.shape[-2]
847825
fmask = None
848826

849-
fk = pad_to_multiple(fk)
850-
fv = pad_to_multiple(fv)
851-
fq = pad_to_multiple(fq)
827+
fk, fv, fq = map(pad_to_multiple, (fk, fv, fq))
852828

853829
fq, fk, fv = tuple(rearrange(t, 'b h (w n) d -> (b w) h n d', n = self.selection_block_size) for t in (fq, fk, fv))
854830

tests/test_custom_compress_mlp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def test_alternative_compress_mlp():
2626
heads = 8,
2727
sliding_window_size = 2,
2828
compress_block_size = 4,
29+
compress_block_sliding_stride=2,
2930
selection_block_size = 4,
3031
num_selected_blocks = 2,
3132
compress_mlp = compress_mlp
@@ -47,6 +48,7 @@ def test_compress_networks():
4748
heads = 8,
4849
sliding_window_size = 2,
4950
compress_block_size = 4,
51+
compress_block_sliding_stride=2,
5052
selection_block_size = 4,
5153
num_selected_blocks = 2,
5254
compress_mlp = AttentionPool(64, 4)
@@ -67,6 +69,7 @@ def test_group_mlp():
6769
heads = 8,
6870
sliding_window_size = 2,
6971
compress_block_size = 4,
72+
compress_block_sliding_stride=2,
7073
selection_block_size = 4,
7174
num_selected_blocks = 2,
7275
compress_mlp = GroupedMLP(64, 4, 8)
@@ -88,6 +91,7 @@ def test_single_projection_mlp(grouped):
8891
heads = 8,
8992
sliding_window_size = 2,
9093
compress_block_size = 4,
94+
compress_block_sliding_stride=2,
9195
selection_block_size = 4,
9296
num_selected_blocks = 2,
9397
compress_mlp = SingleProjection(64, 4, 8 if grouped else 1)
@@ -117,6 +121,7 @@ def test_compress_transformer():
117121
heads=8,
118122
sliding_window_size=64,
119123
compress_block_size=16,
124+
compress_block_sliding_stride=2,
120125
selection_block_size=16,
121126
num_selected_blocks=2,
122127
kv_heads=num_kv_heads,

tests/test_sparse_attn.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77
from native_sparse_attention_pytorch import SparseAttention
88

9+
device = 'cpu'
10+
911
@pytest.mark.parametrize('use_diff_topk', (False, True))
1012
@pytest.mark.parametrize('causal', (False, True))
1113
@pytest.mark.parametrize('seq_len', (1, 4, 31, 32, 120))
1214
@pytest.mark.parametrize('kv_heads', (8, 4))
1315
@pytest.mark.parametrize('selection_block_size', (8, 16, 32))
1416
@pytest.mark.parametrize('compress_block_size', (8, 4))
15-
@pytest.mark.parametrize('compress_block_overlap_len', (0, 2))
17+
@pytest.mark.parametrize('compress_block_sliding_stride', (2, 4))
1618
@pytest.mark.parametrize('num_selected_block', (0, 2))
1719
@pytest.mark.parametrize('query_heads_share_selected_kv', (False, True))
1820
def test_sparse_attn(
@@ -22,7 +24,7 @@ def test_sparse_attn(
2224
kv_heads,
2325
selection_block_size,
2426
compress_block_size,
25-
compress_block_overlap_len,
27+
compress_block_sliding_stride,
2628
num_selected_block,
2729
query_heads_share_selected_kv,
2830
):
@@ -35,26 +37,26 @@ def test_sparse_attn(
3537
sliding_window_size = 2,
3638
selection_block_size = selection_block_size,
3739
compress_block_size = compress_block_size,
38-
compress_block_overlap_len = compress_block_overlap_len,
40+
compress_block_sliding_stride = compress_block_sliding_stride,
3941
num_selected_blocks = num_selected_block,
4042
use_diff_topk = use_diff_topk,
4143
query_heads_share_selected_kv = query_heads_share_selected_kv,
42-
)
44+
).to(device)
4345

44-
tokens = torch.randn(2, seq_len, 512)
46+
tokens = torch.randn(2, seq_len, 512).to(device)
4547

4648
attended = attn(tokens)
4749

4850
assert tokens.shape == attended.shape
4951

5052
@pytest.mark.parametrize('seq_len', (2, 8, 16))
5153
@pytest.mark.parametrize('num_selected_blocks', (0, 2))
52-
@pytest.mark.parametrize('compress_block_overlap_len', (0, 2))
53-
@pytest.mark.parametrize('selection_block_size', (5, 10, 15))
54+
@pytest.mark.parametrize('compress_block_sliding_stride', (2, 4))
55+
@pytest.mark.parametrize('selection_block_size', (8, 16, 20))
5456
def test_inference(
5557
seq_len,
5658
num_selected_blocks,
57-
compress_block_overlap_len,
59+
compress_block_sliding_stride,
5860
selection_block_size
5961
):
6062

@@ -67,10 +69,10 @@ def test_inference(
6769
compress_block_size = 5,
6870
selection_block_size = selection_block_size,
6971
num_selected_blocks = num_selected_blocks,
70-
compress_block_overlap_len = compress_block_overlap_len
71-
)
72+
compress_block_sliding_stride = compress_block_sliding_stride
73+
).to(device)
7274

73-
tokens = torch.randn(2, seq_len, 512)
75+
tokens = torch.randn(2, seq_len, 512).to(device)
7476

7577
parallel_out = attn(tokens)
7678

@@ -100,12 +102,13 @@ def test_transformer_inference(
100102
sparse_attn_kwargs = dict(
101103
sliding_window_size = 16,
102104
compress_block_size = 4,
105+
compress_block_sliding_stride = 2,
103106
selection_block_size = selection_block_size,
104107
num_selected_blocks = 2
105108
)
106-
)
109+
).to(device)
107110

108-
prompt = torch.randint(0, 256, (1, 1))
111+
prompt = torch.randint(0, 256, (1, 1)).to(device)
109112

110113
sampled = model.sample(prompt, 128, temperature = 0., use_cache_kv = False)
111114
sampled_cached = model.sample(prompt, 128, temperature = 0., use_cache_kv = True)

train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
SLIDING_WINDOW_SIZE = 64
4242
COMPRESS_BLOCK_SIZE = 16
43-
COMPRESS_BLOCK_OVERLAP_LEN = 2
43+
COMPRESS_BLOCK_SLIDING_STRIDE = 8
4444

4545
FINE_BLOCK_SIZE = 16
4646
NUM_FINE_SELECTED = 4
@@ -96,10 +96,10 @@ def decode_tokens(tokens):
9696
sparse_attn_kwargs = dict(
9797
sliding_window_size = SLIDING_WINDOW_SIZE,
9898
compress_block_size = COMPRESS_BLOCK_SIZE,
99-
compress_block_overlap_len = COMPRESS_BLOCK_OVERLAP_LEN,
99+
compress_block_sliding_stride = COMPRESS_BLOCK_SLIDING_STRIDE,
100100
compress_mlp = GroupedMLP(
101101
dim_head = 64,
102-
compress_window_size = COMPRESS_BLOCK_SIZE + COMPRESS_BLOCK_OVERLAP_LEN,
102+
compress_window_size = COMPRESS_BLOCK_SIZE,
103103
heads = KV_HEADS,
104104
),
105105
selection_block_size = FINE_BLOCK_SIZE,

0 commit comments

Comments
 (0)