# for 2. and 3., will give them relative positions with rotary - compressed needs to be handled separately (even if they already have intra block absolute positions)

@@ -441,15 +466,21 @@ def forward(

        # fine attention

-       fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))
+       fmask = rearrange(fmask, 'b h ... -> b h 1 ...')

-       fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
+       fq = rearrange(fq, 'b (h qh) ... -> b h qh ...', qh = self.num_grouped_queries)
+
+       fsim = einsum(fq, fk, 'b h qh i d, b h i j d -> b h qh i j') * self.scale
+
+       mask_value = -torch.finfo(fsim.dtype).max

        fsim = fsim.masked_fill(~fmask, mask_value)

        fattn = fsim.softmax(dim = -1)

-       fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
+       fine_attn_out = einsum(fattn, fv, 'b h qh i j, b h i j d -> b h qh i d')
+
+       fine_attn_out = rearrange(fine_attn_out, 'b h qh ... -> b (h qh) ...')
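For readers who want to see the shape bookkeeping end to end, below is a minimal, self-contained sketch (not the repository's code) contrasting the removed repeat-based path with the grouped-query path this hunk introduces. The tensor names fq, fk, fv, fmask mirror the diff; the batch/head/sequence sizes and the use of qh for num_grouped_queries are assumptions made only for this example.

    import torch
    from einops import rearrange, repeat, einsum

    b, kv_heads, qh, i, j, d = 2, 4, 2, 8, 16, 32   # qh plays the role of num_grouped_queries
    scale = d ** -0.5

    fq    = torch.randn(b, kv_heads * qh, i, d)     # one query per (kv head, group)
    fk    = torch.randn(b, kv_heads, i, j, d)       # fine keys, per kv head
    fv    = torch.randn(b, kv_heads, i, j, d)       # fine values, per kv head
    fmask = torch.rand(b, kv_heads, i, j) > 0.1     # per kv-head fine attention mask

    # old path: repeat keys / values / mask across the query groups
    rk, rv, rmask = (repeat(t, 'b h ... -> b (h qh) ...', qh = qh) for t in (fk, fv, fmask))
    sim = einsum(fq, rk, 'b h i d, b h i j d -> b h i j') * scale
    sim = sim.masked_fill(~rmask, -torch.finfo(sim.dtype).max)
    out_old = einsum(sim.softmax(dim = -1), rv, 'b h i j, b h i j d -> b h i d')

    # new path: keep keys / values / mask per kv head and carry an explicit group axis
    gq = rearrange(fq, 'b (h qh) ... -> b h qh ...', qh = qh)
    gmask = rearrange(fmask, 'b h ... -> b h 1 ...')
    gsim = einsum(gq, fk, 'b h qh i d, b h i j d -> b h qh i j') * scale
    gsim = gsim.masked_fill(~gmask, -torch.finfo(gsim.dtype).max)
    out_new = einsum(gsim.softmax(dim = -1), fv, 'b h qh i j, b h i j d -> b h qh i d')
    out_new = rearrange(out_new, 'b h qh ... -> b (h qh) ...')

    assert torch.allclose(out_old, out_new, atol = 1e-5)

Both paths produce the same output; the grouped form simply stops materializing num_grouped_queries copies of the fine keys, values and mask, and instead broadcasts the mask over the added singleton group dimension.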