diff --git a/flash_pytorch/flash_pytorch.py b/flash_pytorch/flash_pytorch.py
index 197499f..8359345 100644
--- a/flash_pytorch/flash_pytorch.py
+++ b/flash_pytorch/flash_pytorch.py
@@ -264,7 +264,7 @@ def forward(
         j - sequence dimension (target)
         """
 
-        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
+        b, n, device, c = x.shape[0], x.shape[-2], x.device, self.group_size
 
         # prenorm
 
@@ -299,24 +299,23 @@ def forward(
 
         # padding for groups
 
-        padding = padding_to_multiple_of(n, g)
+        padding = padding_to_multiple_of(n, c)
 
         if padding > 0:
             quad_q, quad_k, lin_q, lin_k, v = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v))
 
-            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
             mask = F.pad(mask, (0, padding), value = False)
 
         # group along sequence
 
-        quad_q, quad_k, lin_q, lin_k, v = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v))
+        quad_q, quad_k, lin_q, lin_k, v = map(lambda t: rearrange(t, 'b (g c) d -> b g c d', c = c), (quad_q, quad_k, lin_q, lin_k, v))
 
         if exists(mask):
-            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
+            mask = rearrange(mask, 'b (g c) -> b g 1 c', c = c)
 
         # calculate quadratic attention output
 
-        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
+        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / c
 
         sim = sim + self.rel_pos_bias(sim)
 
@@ -327,7 +326,7 @@ def forward(
             attn = attn.masked_fill(~mask, 0.)
 
         if self.causal:
-            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
+            causal_mask = torch.ones((c, c), dtype = torch.bool, device = device).triu(1)
             attn = attn.masked_fill(causal_mask, 0.)
 
         quad_out = einsum('... i j, ... j d -> ... i d', attn, v)
@@ -335,7 +334,7 @@ def forward(
         # calculate linear attention output
 
         if self.causal:
-            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
+            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / c
 
             # exclusive cumulative sum along group dimension
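
For reference, below is a minimal standalone sketch (not part of the diff) of the padding-and-grouping step the rename touches. In the einops patterns, `g` is the number of groups and `c` is the group size (tokens per group); the diff renames the Python variable holding `self.group_size` from `g` to `c` so it matches the pattern axis it binds. `padding_to_multiple_of` is restated here, assuming the same behavior as the repo helper, to keep the example self-contained.

import torch
import torch.nn.functional as F
from einops import rearrange

def padding_to_multiple_of(n, mult):
    # right-padding needed so n becomes a multiple of mult
    # (assumed to match the helper in flash_pytorch.py)
    remainder = n % mult
    return 0 if remainder == 0 else mult - remainder

b, n, d, c = 2, 10, 8, 4                 # batch, sequence length, feature dim, group size
x = torch.randn(b, n, d)

padding = padding_to_multiple_of(n, c)   # 2, since 12 is the next multiple of 4 after 10
if padding > 0:
    x = F.pad(x, (0, 0, 0, padding), value = 0.)   # pad the sequence dimension on the right

# 'g' = number of groups, 'c' = tokens per group
x = rearrange(x, 'b (g c) d -> b g c d', c = c)
print(x.shape)   # torch.Size([2, 3, 4, 8])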