@@ -95,8 +95,10 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
9595 row_maxes .copy_ (new_row_maxes )
9696 row_sums .copy_ (new_row_sums )
9797
98+ lse = all_row_sums .log () + all_row_maxes
99+
98100 ctx .args = (causal , scale , mask , q_bucket_size , k_bucket_size )
99- ctx .save_for_backward (q , k , v , o , all_row_sums , all_row_maxes )
101+ ctx .save_for_backward (q , k , v , o , lse )
100102
101103 return o
102104
@@ -106,7 +108,7 @@ def backward(ctx, do):
106108 """ Algorithm 4 in the paper """
107109
108110 causal , scale , mask , q_bucket_size , k_bucket_size = ctx .args
109- q , k , v , o , l , m = ctx .saved_tensors
111+ q , k , v , o , lse = ctx .saved_tensors
110112
111113 device = q .device
112114
@@ -122,12 +124,11 @@ def backward(ctx, do):
122124 o .split (q_bucket_size , dim = - 2 ),
123125 do .split (q_bucket_size , dim = - 2 ),
124126 mask ,
125- l .split (q_bucket_size , dim = - 2 ),
126- m .split (q_bucket_size , dim = - 2 ),
127+ lse .split (q_bucket_size , dim = - 2 ),
127128 dq .split (q_bucket_size , dim = - 2 )
128129 )
129130
130- for ind , (qc , oc , doc , row_mask , lc , mc , dqc ) in enumerate (row_splits ):
131+ for ind , (qc , oc , doc , row_mask , lsec , dqc ) in enumerate (row_splits ):
131132 q_start_index = ind * q_bucket_size - qk_len_diff
132133
133134 col_splits = zip (
@@ -146,12 +147,10 @@ def backward(ctx, do):
146147 causal_mask = torch .ones ((qc .shape [- 2 ], kc .shape [- 2 ]), dtype = torch .bool , device = device ).triu (q_start_index - k_start_index + 1 )
147148 attn_weights .masked_fill_ (causal_mask , max_neg_value )
148149
149- exp_attn_weights = torch .exp (attn_weights - mc )
150+ p = torch .exp (attn_weights - lsec )
150151
151152 if exists (row_mask ):
152- exp_attn_weights .masked_fill_ (~ row_mask , 0. )
153-
154- p = exp_attn_weights / lc
153+ p .masked_fill_ (~ row_mask , 0. )
155154
156155 dv_chunk = einsum ('... i j, ... i d -> ... j d' , p , doc )
157156 dp = einsum ('... i d, ... j d -> ... i j' , doc , vc )
0 commit comments