@@ -125,10 +125,10 @@ module subroutine backward(self, input, gradient)
 real, allocatable :: k_heads(:, :, :, :)
 real, allocatable :: q_heads(:, :, :, :)
 real, allocatable :: dv(:, :, :, :)
-real, allocatable :: d_sdpa(:, :, :, :)
-real, allocatable :: jacobian(:, :, :)
+real, allocatable :: d_sdpa(:, :)
+real, allocatable :: jacobian(:, :)
 real, allocatable :: d_normalize(:, :, :, :)
-real, allocatable :: d_attn_matrix(:, :, :, :)
+real, allocatable :: dq(:, :, :, :)
 real, allocatable :: dk(:, :, :, :)
 integer :: batch, head, seq, i, j
 
@@ -139,10 +139,10 @@ module subroutine backward(self, input, gradient)
 allocate(q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
 
 allocate(dv(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
-allocate(d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
-allocate(jacobian(self % sequence_length, self % sequence_length, self % sequence_length))
+allocate(d_sdpa(self % sequence_length, self % sequence_length))
+allocate(jacobian(self % sequence_length, self % sequence_length))
 allocate(d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
-allocate(d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
+allocate(dq(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
 allocate(dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
 
 ! calculate output layer delta
@@ -159,39 +159,38 @@ module subroutine backward(self, input, gradient)
 dv(head, :, :, batch) = matmul(transpose(self % attention_matrix(head, :, :, batch)), d_output(head, :, :, batch))
 
 ! calculate delta for attention matrix
-d_sdpa(head, :, :, batch) = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch)))
+d_sdpa = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch)))
 
 ! this monstrosity below is scaled derivative of softmax
-do concurrent(seq = 1 : self % sequence_length, i = 1 : self % sequence_length, j = 1 : self % sequence_length)
-  ! jacobian matrix is used to calculate derivative of softmax (temporary storage)
-  ! the idea behind this if-else is that for diagonal elements, the jacobian temp
-  ! should be: `softmax(x_i) * (1 - softmax(x_i))`
-  ! for off-diagonal: `-softmax(x_i) * softmax(x_j)`
-  ! For computational efficiency (avoid more temp storages), scaling is also done here
-  if (i == j) then
-    jacobian(seq, i, j) = &
-      self % attention_matrix(head, seq, i, batch) &
-      * (1 - self % attention_matrix(head, seq, i, batch)) &
-      * self % scaling_factor
-  else
-    jacobian(seq, i, j) = &
-      - self % attention_matrix(head, seq, i, batch) &
-      * self % attention_matrix(head, seq, j, batch) &
-      * self % scaling_factor
-  end if
-end do
-
-! attention normalization delta, the last step of softmax derivative:
-! multiply temp jacobian matrix by the output of softmax
 do concurrent(seq = 1 : self % sequence_length)
+  ! create jacobian matrix
+  do concurrent(i = 1 : self % sequence_length, j = 1 : self % sequence_length)
+    ! jacobian matrix is used to calculate derivative of softmax (temporary storage)
+    ! the idea behind this if-else is that for diagonal elements, the jacobian temp
+    ! should be: `softmax(x_i) * (1 - softmax(x_i))`
+    ! for off-diagonal: `-softmax(x_i) * softmax(x_j)`
+    if (i == j) then
+      jacobian(i, j) = &
+        self % attention_matrix(head, seq, i, batch) &
+        * (1 - self % attention_matrix(head, seq, i, batch))
+    else
+      jacobian(i, j) = &
+        - self % attention_matrix(head, seq, i, batch) &
+        * self % attention_matrix(head, seq, j, batch)
+    end if
+  end do
+  ! attention normalization delta, the last step of softmax derivative:
+  ! multiply output of softmax by temp jacobian matrix
+  ! For computational efficiency (avoid more temp storages), scaling is also done here
+  ! reshapes: [3] -> [1, 3] @ [3, 3] = [1, 3] -> [3]
   d_normalize(head, seq, :, batch) = reshape(matmul(&
-    reshape(d_sdpa(head, seq, :, batch), [1, self % sequence_length]),&
-    jacobian(seq, :, :) &
+    reshape(d_sdpa(seq, :), [1, self % sequence_length]),&
+    jacobian * self % scaling_factor &
   ), [self % sequence_length])
 end do
 
 ! calculate delta for query
-d_attn_matrix(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch))
+dq(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch))
 
 ! calculate delta for key, attention matrix should be transposed unlike for query
 dk(head, :, :, batch) = matmul(transpose(d_normalize(head, :, :, batch)), q_heads(head, :, :, batch))
@@ -200,7 +199,7 @@ module subroutine backward(self, input, gradient)
 ! calculate deltas for input layers
 call self % value_layer % backward(self % v_input, self % combine_heads(dv))
 call self % key_layer % backward(self % k_input, self % combine_heads(dk))
-call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix))
+call self % query_layer % backward(self % q_input, self % combine_heads(dq))
 
 ! free temporary storages
 deallocate(d_output)
@@ -210,7 +209,7 @@ module subroutine backward(self, input, gradient)
 deallocate(d_sdpa)
 deallocate(jacobian)
 deallocate(d_normalize)
-deallocate(d_attn_matrix)
+deallocate(dq)
 deallocate(dk)
 end subroutine backward
 
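For readers reviewing the softmax derivative above, here is a minimal standalone sketch of the same row-wise vector-Jacobian product. It is not part of this commit; the program name `softmax_vjp_check`, the row length, and the sample values are all made up for illustration. It builds the Jacobian with `softmax(x_i) * (1 - softmax(x_i))` on the diagonal and `-softmax(x_i) * softmax(x_j)` off it, then left-multiplies by the upstream gradient row, which is the same contraction the `reshape`/`matmul` in the new loop performs per `(head, seq, batch)` slice; the commit additionally folds `self % scaling_factor` into the Jacobian to chain through the pre-softmax scaling.

! Standalone sketch, not part of this commit: checks the row-wise softmax
! vector-Jacobian product implemented in the backward pass above.
program softmax_vjp_check
  implicit none
  integer, parameter :: n = 3
  real :: x(n), s(n), upstream(n), jacobian(n, n), dx(n)
  integer :: i, j

  x = [0.5, -1.0, 2.0]        ! one row of (scaled) attention scores, illustrative values
  upstream = [1.0, 0.0, -1.0] ! incoming gradient, plays the role of d_sdpa(seq, :)

  ! softmax of the row (shifted by maxval for numerical stability)
  s = exp(x - maxval(x))
  s = s / sum(s)

  ! Jacobian of softmax: s_i * (1 - s_i) on the diagonal, -s_i * s_j off it,
  ! matching the if-else in the diff
  do i = 1, n
    do j = 1, n
      if (i == j) then
        jacobian(i, j) = s(i) * (1 - s(i))
      else
        jacobian(i, j) = -s(i) * s(j)
      end if
    end do
  end do

  ! row vector times Jacobian, the same product the reshape/matmul above
  ! computes; the diff also multiplies the Jacobian by scaling_factor here
  dx = matmul(upstream, jacobian)

  print *, 'dL/dx =', dx
end program softmax_vjp_check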