@@ -7,8 +7,6 @@
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList
 
-from colt5_attention import topk as differentiable_topk
-
 from local_attention import LocalAttention
 
 from rotary_embedding_torch import RotaryEmbedding
@@ -87,7 +85,6 @@ def __init__(
         num_compressed_mem_kv = 4,
         norm = True,
         use_diff_topk = False,
-        diff_topk_coor_descent_iters = 10.
     ):
         super().__init__()
         self.heads = heads
@@ -142,7 +139,6 @@ def __init__(
         # selection related
 
         self.use_diff_topk = use_diff_topk
-        self.diff_topk_coor_descent_iters = diff_topk_coor_descent_iters
 
         self.selection_block_size = selection_block_size
         self.num_selected_blocks = num_selected_blocks
@@ -222,12 +218,12 @@ def forward(
 
         # 2. fine attention over selected based on compressed attention logits
 
-        importance_scores = csim[..., num_mem_compress_kv:]
+        importance_scores = cattn[..., num_mem_compress_kv:]
+
+        selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)
 
         if self.use_diff_topk:
-            selected_importance_values, selected_block_indices, _, gates = differentiable_topk(importance_scores, self.num_selected_blocks, fused = True)
-        else:
-            selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)
+            gates = selected_importance_values + (1. - selected_importance_values).detach()
 
         fmask = selected_importance_values > mask_value
 
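Note: the new gating path above is a straight-through estimator replacing the coordinate-descent top-k from colt5_attention. A minimal standalone sketch of the trick, with an illustrative helper name and toy shapes that are not part of the repository:

import torch

def straight_through_topk_gates(importance_scores, k):
    # hard top-k selection of block importance scores
    selected_values, selected_indices = importance_scores.topk(k, dim = -1)

    # forward value of each gate is exactly 1., while the gradient of anything
    # multiplied by the gate flows back through the selected (soft) values
    gates = selected_values + (1. - selected_values).detach()
    return gates, selected_indices

# toy usage: scores of shape (batch, heads, num blocks)
scores = torch.rand(2, 4, 16, requires_grad = True)
gates, indices = straight_through_topk_gates(scores, k = 3)

gates.sum().backward()
assert torch.allclose(gates, torch.ones_like(gates))   # identity gating in the forward pass
assert scores.grad is not None                          # but gradients still reach the scores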
@@ -247,6 +243,9 @@ def forward(
 
         selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)
 
+        if self.use_diff_topk:
+            gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)
+
         # handle block causal diagonal in the diagram, but run experiments without to see
 
         fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
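pad_at_dim here is a repository helper; a rough equivalent, assuming it pads the chosen dimension by (left, right) amounts with a fill value, would look like the sketch below. The shapes are toy values for illustration only.

import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # assumed equivalent of the repo helper, for illustration only
    dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

# pad the per-block gates along the sequence dimension (dim = -2) up to the
# fine-divisible length, filling with 1. so the padded positions pass through ungated
gates = torch.ones(1, 2, 7, 5)   # (batch, heads, seq, selected blocks) - toy shape
remainder = 1
gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)
assert gates.shape == (1, 2, 8, 5)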
@@ -272,7 +271,7 @@ def forward(
         # handle maybe gating
 
         if self.use_diff_topk:
-            gates = F.pad(gates, (0, 1, 0, remainder), value = 1.)
+            gates = F.pad(gates, (0, 1), value = 1.)
 
             fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
             fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
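The einx.multiply calls broadcast one gate per selected block over that block's tokens and feature dimension; the extra gate of 1. appended by F.pad presumably lines up with the block-causal diagonal block mentioned earlier in the diff, so it passes through ungated. A plain-PyTorch equivalent under assumed shapes:

import torch

b, h, i, w, j, d = 1, 2, 16, 5, 4, 32    # toy sizes, not from the repository
gates = torch.ones(b, h, i, w)           # (batch, heads, seq, selected blocks + 1)
fk = torch.randn(b, h, i, w, j, d)       # fine keys gathered per selected block

# same effect as einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
gated_fk = fk * gates[..., None, None]
assert gated_fk.shape == fk.shape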