Commit 89d907a

multihead_attention: minor refactoring and optimization
1 parent: 66dfb59

1 file changed

src/nf/nf_multihead_attention.f90

Lines changed: 32 additions & 33 deletions
@@ -125,10 +125,10 @@ module subroutine backward(self, input, gradient)
     real, allocatable :: k_heads(:, :, :, :)
     real, allocatable :: q_heads(:, :, :, :)
     real, allocatable :: dv(:, :, :, :)
-    real, allocatable :: d_sdpa(:, :, :, :)
-    real, allocatable :: jacobian(:, :, :)
+    real, allocatable :: d_sdpa(:, :)
+    real, allocatable :: jacobian(:, :)
     real, allocatable :: d_normalize(:, :, :, :)
-    real, allocatable :: d_attn_matrix(:, :, :, :)
+    real, allocatable :: dq(:, :, :, :)
     real, allocatable :: dk(:, :, :, :)
     integer :: batch, head, seq, i, j
 
@@ -139,10 +139,10 @@ module subroutine backward(self, input, gradient)
     allocate(q_heads(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
 
     allocate(dv(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
-    allocate(d_sdpa(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
-    allocate(jacobian(self % sequence_length, self % sequence_length, self % sequence_length))
+    allocate(d_sdpa(self % sequence_length, self % sequence_length))
+    allocate(jacobian(self % sequence_length, self % sequence_length))
     allocate(d_normalize(self % n_heads, self % sequence_length, self % sequence_length, self % batch_size))
-    allocate(d_attn_matrix(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
+    allocate(dq(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
     allocate(dk(self % n_heads, self % sequence_length, self % head_size, self % batch_size))
 
     ! calculate output layer delta
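
(Aside, not part of the diff.) To see why shrinking these temporaries matters, take hypothetical dimensions, say sequence_length L = 512, n_heads H = 8, batch_size B = 32, with 4-byte reals; these numbers are illustrative, not taken from the commit. The old and new allocations then compare as

$$ \texttt{jacobian}:\; L^3 = 512^3 \approx 1.3\times 10^{8}\ \text{reals} \approx 512\ \text{MiB} \;\longrightarrow\; L^2 = 512^2 \approx 2.6\times 10^{5}\ \text{reals} = 1\ \text{MiB} $$
$$ \texttt{d\_sdpa}:\; H L^2 B \approx 6.7\times 10^{7}\ \text{reals} = 256\ \text{MiB} \;\longrightarrow\; L^2 = 1\ \text{MiB} $$

because jacobian is now rebuilt for each row and d_sdpa for each (head, batch) pair, instead of being materialized for every index at once.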
@@ -159,39 +159,38 @@ module subroutine backward(self, input, gradient)
       dv(head, :, :, batch) = matmul(transpose(self % attention_matrix(head, :, :, batch)), d_output(head, :, :, batch))
 
       ! calculate delta for attention matrix
-      d_sdpa(head, :, :, batch) = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch)))
+      d_sdpa = matmul(d_output(head, :, :, batch), transpose(v_heads(head, :, :, batch)))
 
       ! this monstrosity below is scaled derivative of softmax
-      do concurrent(seq = 1: self % sequence_length, i = 1: self % sequence_length, j = 1: self % sequence_length)
-        ! jacobian matrix is used to calculate derivative of softmax (temporary storage)
-        ! the idea behind this if-else is that for diagonal elements, the jacobian temp
-        ! should be: `softmax(x_i) * (1 - softmax(x_i))`
-        ! for off-diagonal: `-softmax(x_i) * softmax(x_j)`
-        ! For computational efficiency (avoid more temp storages), scaling is also done here
-        if (i == j) then
-          jacobian(seq, i, j) = &
-              self % attention_matrix(head, seq, i, batch) &
-              * (1 - self % attention_matrix(head, seq, i, batch)) &
-              * self % scaling_factor
-        else
-          jacobian(seq, i, j) = &
-              - self % attention_matrix(head, seq, i, batch) &
-              * self % attention_matrix(head, seq, j, batch) &
-              * self % scaling_factor
-        end if
-      end do
-
-      ! attention normalization delta, the last step of softmax derivative:
-      ! multiply temp jacobian matrix by the output of softmax
       do concurrent(seq = 1: self % sequence_length)
+        ! create jacobian matrix
+        do concurrent(i = 1: self % sequence_length, j = 1: self % sequence_length)
+          ! jacobian matrix is used to calculate derivative of softmax (temporary storage)
+          ! the idea behind this if-else is that for diagonal elements, the jacobian temp
+          ! should be: `softmax(x_i) * (1 - softmax(x_i))`
+          ! for off-diagonal: `-softmax(x_i) * softmax(x_j)`
+          if (i == j) then
+            jacobian(i, j) = &
+                self % attention_matrix(head, seq, i, batch) &
+                * (1 - self % attention_matrix(head, seq, i, batch))
+          else
+            jacobian(i, j) = &
+                - self % attention_matrix(head, seq, i, batch) &
+                * self % attention_matrix(head, seq, j, batch)
+          end if
+        end do
+        ! attention normalization delta, the last step of softmax derivative:
+        ! multiply output of softmax by temp jacobian matrix
+        ! For computational efficiency (avoid more temp storages), scaling is also done here
+        ! reshapes: [3] -> [1, 3] @ [3, 3] = [1, 3] -> [3]
         d_normalize(head, seq, :, batch) = reshape(matmul(&
-            reshape(d_sdpa(head, seq, :, batch), [1, self % sequence_length]),&
-            jacobian(seq, :, :)&
+            reshape(d_sdpa(seq, :), [1, self % sequence_length]),&
+            jacobian * self % scaling_factor&
         ), [self % sequence_length])
       end do
 
       ! calculate delta for query
-      d_attn_matrix(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch))
+      dq(head, :, :, batch) = matmul(d_normalize(head, :, :, batch), k_heads(head, :, :, batch))
 
       ! calculate delta for key, attention matrix should be transposed unlike for query
       dk(head, :, :, batch) = matmul(transpose(d_normalize(head, :, :, batch)), q_heads(head, :, :, batch))
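
(Aside, not part of the diff.) Writing s for one row of self % attention_matrix(head, :, :, batch) (a softmax output) and g for the matching row of d_sdpa (the gradient with respect to that output), the if/else above builds the softmax Jacobian and the reshape/matmul line applies it as a vector-Jacobian product with the score scaling folded in:

$$ J_{ij} = \frac{\partial s_i}{\partial z_j} = s_i\,(\delta_{ij} - s_j), \qquad (\texttt{d\_normalize})_j = \lambda \sum_i g_i\, J_{ij} = \lambda\, s_j\Big(g_j - \sum_i g_i\, s_i\Big), $$

where z are the pre-softmax attention scores and lambda is self % scaling_factor (presumably 1/sqrt(head_size)). The `[L] -> [1, L] @ [L, L] -> [L]` reshape dance noted in the comment is exactly this row-vector-times-matrix product expressed with matmul.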
@@ -200,7 +199,7 @@ module subroutine backward(self, input, gradient)
     ! calculate deltas for input layers
     call self % value_layer % backward(self % v_input, self % combine_heads(dv))
     call self % key_layer % backward(self % k_input, self % combine_heads(dk))
-    call self % query_layer % backward(self % q_input, self % combine_heads(d_attn_matrix))
+    call self % query_layer % backward(self % q_input, self % combine_heads(dq))
 
     ! free temporary storages
     deallocate(d_output)
@@ -210,7 +209,7 @@ module subroutine backward(self, input, gradient)
     deallocate(d_sdpa)
     deallocate(jacobian)
     deallocate(d_normalize)
-    deallocate(d_attn_matrix)
+    deallocate(dq)
     deallocate(dk)
   end subroutine backward
 
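A minimal standalone sketch of the refactored step (illustrative only, not code from the repository): it builds the 2-d Jacobian once per row and reuses it, just as the loop in this commit does for each head and batch element. Array names and sizes are made up for the example, and scaling_factor is assumed to stand in for the attention score scaling.

program softmax_backward_sketch
  implicit none
  integer, parameter :: n = 4
  real :: attn(n, n)        ! softmax output, one row per query position
  real :: d_sdpa(n, n)      ! upstream gradient w.r.t. the softmax output
  real :: jacobian(n, n)    ! 2-d temporary, rebuilt and reused for each row
  real :: d_normalize(n, n) ! gradient w.r.t. the pre-softmax scores
  real :: scaling_factor
  integer :: seq, i, j

  ! fabricate a softmax-like attention matrix and an upstream gradient
  call random_number(attn)
  do seq = 1, n
    attn(seq, :) = attn(seq, :) / sum(attn(seq, :))
  end do
  call random_number(d_sdpa)
  scaling_factor = 1.0 / sqrt(real(n))

  do seq = 1, n
    ! jacobian(i, j) = s_i * (delta_ij - s_j) for the current row s = attn(seq, :)
    do concurrent(i = 1:n, j = 1:n)
      if (i == j) then
        jacobian(i, j) = attn(seq, i) * (1 - attn(seq, i))
      else
        jacobian(i, j) = - attn(seq, i) * attn(seq, j)
      end if
    end do
    ! row vector times scaled Jacobian: [1, n] x [n, n] -> [1, n] -> [n]
    d_normalize(seq, :) = reshape(matmul( &
        reshape(d_sdpa(seq, :), [1, n]), &
        jacobian * scaling_factor &
    ), [n])
  end do

  print *, d_normalize
end program softmax_backward_sketch

Keeping jacobian and d_sdpa two-dimensional is what lets the commit drop the head, batch and (for jacobian) row dimensions from the temporaries.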