|
2 | 2 |
|
3 | 3 | from copy import deepcopy
|
4 | 4 | from math import ceil
|
| 5 | +from functools import partial |
5 | 6 |
|
6 | 7 | import torch
|
7 | 8 | import torch.nn.functional as F
|
@@ -483,7 +484,7 @@ def forward_inference(
|
483 | 484 |
|
484 | 485 | # block causal diagonal
|
485 | 486 |
|
486 |
| - fine_sliding_window = (seq_len % self.selection_block_size) + 1 |
| 487 | + fine_sliding_window = ((seq_len - 1) % self.selection_block_size) + 1 |
487 | 488 | fk = k[..., -fine_sliding_window:, :]
|
488 | 489 | fv = v[..., -fine_sliding_window:, :]
|
489 | 490 |
|
@@ -721,6 +722,9 @@ def forward(
|
721 | 722 | num_selected = min(num_selected, importance_scores.shape[-1])
|
722 | 723 | has_selected_kv_for_fine_attn = num_selected > 0
|
723 | 724 |
|
| 725 | + remainder = fine_divisible_seq_len - seq_len |
| 726 | + pad_to_multiple = partial(pad_at_dim, pad = (0, remainder), dim = -2) |
| 727 | + |
724 | 728 | if has_selected_kv_for_fine_attn:
|
725 | 729 |
|
726 | 730 | # get the top-n kv segments for fine attention
|
@@ -760,10 +764,9 @@ def forward(
|
760 | 764 | fmask = selected_importance_values > 1e-10
|
761 | 765 |
|
762 | 766 | if seq_len < fine_divisible_seq_len:
|
763 |
| - remainder = fine_divisible_seq_len - seq_len |
764 |
| - fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2) |
765 |
| - fv = pad_at_dim(fv, (0, remainder), value = 0., dim = -2) |
766 |
| - fq = pad_at_dim(fq, (0, remainder), value = 0., dim = -2) |
| 767 | + fk = pad_to_multiple(fk) |
| 768 | + fv = pad_to_multiple(fv) |
| 769 | + fq = pad_to_multiple(fq) |
767 | 770 |
|
768 | 771 | fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)
|
769 | 772 |
|
@@ -845,11 +848,20 @@ def forward(
|
845 | 848 | seq_len = fk.shape[-2]
|
846 | 849 | fmask = None
|
847 | 850 |
|
| 851 | + fk = pad_to_multiple(fk) |
| 852 | + fv = pad_to_multiple(fv) |
| 853 | + fq = pad_to_multiple(fq) |
| 854 | + |
| 855 | + fq, fk, fv = tuple(rearrange(t, 'b h (w n) d -> (b w) h n d', n = self.selection_block_size) for t in (fq, fk, fv)) |
| 856 | + |
848 | 857 | if self.causal:
|
849 |
| - fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril() |
| 858 | + fmask = causal_mask = torch.ones((self.selection_block_size, self.selection_block_size), device = device, dtype = torch.bool).tril() |
850 | 859 |
|
851 | 860 | fine_attn_out = attend(fq, fk, fv, mask = fmask)
|
852 | 861 |
|
| 862 | + fine_attn_out = rearrange(fine_attn_out, '(b w) h n d -> b h (w n) d', b = batch) |
| 863 | + fine_attn_out = fine_attn_out[..., :seq_len, :] |
| 864 | + |
853 | 865 | # 3. overlapping sliding window, this is unsurprising and expected - `s` for sliding
|
854 | 866 |
|
855 | 867 | sq = q
|
|
0 commit comments