|
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from copy import deepcopy
 from math import ceil
 
 import torch
@@ -85,6 +86,9 @@ def __init__(
         num_compressed_mem_kv = 4,
         norm = True,
         use_diff_topk = False,
+        compress_mlp: Module | None = None,
+        compress_mlp_expand_factor = 1.,
+
     ):
         super().__init__()
         self.heads = heads
@@ -120,21 +124,26 @@ def __init__(
 
         assert num_compressed_mem_kv > 0
 
+        self.split_compress_window = Rearrange('b h (w n) d -> b h w n d', n = compress_block_size)
+
         self.compress_mem_kv = nn.Parameter(torch.zeros(2, heads, num_compressed_mem_kv, dim_head))
+
         self.k_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))
         self.v_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))
 
-        self.k_compress = nn.Sequential(
-            Rearrange('b h n d -> b (h d) n'),
-            nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride = compress_block_size, groups = heads),
-            Rearrange('b (h d) nc -> b h nc d', h = heads)
-        )
+        if not exists(compress_mlp):
+            compress_dim = compress_block_size * dim_head
+            compress_mlp_dim_hidden = int(compress_mlp_expand_factor * compress_dim)
 
-        self.v_compress = nn.Sequential(
-            Rearrange('b h n d -> b (h d) n'),
-            nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride = compress_block_size, groups = heads),
-            Rearrange('b (h d) nc -> b h nc d', h = heads)
-        )
+            mlp = nn.Sequential(
+                Rearrange('b h w n d -> b h w (n d)'),
+                nn.Linear(compress_dim, compress_mlp_dim_hidden),
+                nn.SiLU(),
+                nn.Linear(compress_mlp_dim_hidden, dim_head),
+            )
+
+            self.k_compress = deepcopy(mlp)
+            self.v_compress = deepcopy(mlp)
 
         # selection related
 
@@ -187,8 +196,11 @@ def forward(
         k_pos = repeat(self.k_intrablock_positions, 'h n d -> h (r n) d', r = num_compress_blocks)
         v_pos = repeat(self.v_intrablock_positions, 'h n d -> h (r n) d', r = num_compress_blocks)
 
-        ck = self.k_compress(k[..., :compress_divisible_seq_len, :] + k_pos)
-        cv = self.v_compress(v[..., :compress_divisible_seq_len, :] + v_pos)
+        k_compress_input = self.split_compress_window(k[..., :compress_divisible_seq_len, :] + k_pos)
+        v_compress_input = self.split_compress_window(v[..., :compress_divisible_seq_len, :] + v_pos)
+
+        ck = self.k_compress(k_compress_input)
+        cv = self.v_compress(v_compress_input)
 
         # 1. coarse attention over compressed
 
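In summary, the commit swaps the grouped Conv1d compression of keys and values for a per-window MLP: the sequence is first split into non-overlapping windows of compress_block_size tokens, then each window is flattened and projected down to a single dim_head vector. Below is a minimal, self-contained sketch of that shape flow, not code from the commit; the batch size, heads, dim_head, compress_block_size and seq_len values are illustrative assumptions, and the expand factor is taken at its default of 1.

# illustrative sketch of the per-window MLP compression introduced by this commit
# (hyperparameter values below are assumptions, not taken from the repository)

import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, heads, dim_head = 2, 8, 64
compress_block_size, seq_len = 4, 16           # seq_len assumed divisible by the block size

# split the sequence into non-overlapping windows of compress_block_size tokens
split_compress_window = Rearrange('b h (w n) d -> b h w n d', n = compress_block_size)

compress_dim = compress_block_size * dim_head  # width of one flattened window
hidden_dim = int(1. * compress_dim)            # compress_mlp_expand_factor defaults to 1.

# default compression MLP: flatten each window, then project it to one dim_head vector
k_compress = nn.Sequential(
    Rearrange('b h w n d -> b h w (n d)'),
    nn.Linear(compress_dim, hidden_dim),
    nn.SiLU(),
    nn.Linear(hidden_dim, dim_head),
)

k = torch.randn(batch, heads, seq_len, dim_head)
ck = k_compress(split_compress_window(k))

# one compressed key per window of compress_block_size tokens
assert ck.shape == (batch, heads, seq_len // compress_block_size, dim_head)

Compared with the removed grouped Conv1d (a per-head strided convolution over the sequence dimension), the MLP sees each window as a flat (compress_block_size * dim_head) vector per head; presumably a custom compress_mlp supplied through the new keyword argument is expected to accept the same 'b h w n d' input and return 'b h w d'.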
|
|