You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
ck=self.k_compress(k_compress_input) # Equation (7) of the Native Sparse Attention paper
299
301
cv=self.v_compress(v_compress_input)
300
302
@@ -307,9 +309,9 @@ def forward(
307
309
ck=cat((mem_ck, ck), dim=-2)
308
310
cv=cat((mem_cv, cv), dim=-2)
309
311
310
-
ck, cv=tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries=self.num_grouped_queries) for t in (ck, cv))
312
+
cq=rearrange(cq, 'b (h qh) ... -> b h qh ...', qh=self.num_grouped_queries)
311
313
312
-
csim=einsum(q, ck, 'b h i d, b h j d -> b h i j') *self.scale
314
+
csim=einsum(cq, ck, 'b h qh i d, b h j d -> b h qh i j') *self.scale
313
315
314
316
cq_seq=arange(seq_len, device=device)
315
317
@@ -324,7 +326,9 @@ def forward(
324
326
325
327
cattn=csim.softmax(dim=-1)
326
328
327
-
compressed_attn_out=einsum(cattn, cv, 'b h i j, b h j d -> b h i d')
329
+
compressed_attn_out=einsum(cattn, cv, 'b h qh i j, b h j d -> b h qh i d')
330
+
331
+
compressed_attn_out, cattn=tuple(rearrange(t, 'b h qh ... -> b (h qh) ...') for t in (compressed_attn_out, cattn))
328
332
329
333
# for 2. and 3., will give them relative positions with rotary - compressed needs to be handled separately (even if they already have intra block absolute positions)
0 commit comments