
Commit 4548316

multihead_attention: fix comments
1 parent 89d907a commit 4548316

1 file changed: +10 -18 lines changed


src/nf/nf_multihead_attention.f90

Lines changed: 10 additions & 18 deletions
@@ -55,9 +55,10 @@ end function multihead_attention_layer_cons
   interface
 
     module subroutine backward(self, input, gradient)
-      !! Apply the backward gradient descent pass.
-      !! Only weight and bias gradients are updated in this subroutine,
-      !! while the weights and biases themselves are untouched.
+      !! General backprop for MultiHead Attention mechanism
+      !! Might be used for both Self and Cross Attention
+      !! Self Attention: sum output gradients
+      !! Cross Attention: use them separately
       class(multihead_attention_layer), intent(in out) :: self
         !! Dense layer instance
       real, intent(in) :: input(:, :, :)
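
The new comment distinguishes how the three per-argument gradients are combined. Below is a minimal standalone sketch of that rule on plain arrays, assuming the (sequence_length, model_dimension, batch_size) layout used elsewhere in this file; the array names are illustrative, not the layer's actual components.

program self_vs_cross_gradients
  implicit none
  ! Illustrative shapes only: 3 sequence positions, model dimension 4, batch size 1.
  real :: d_query(3, 4, 1), d_key(3, 4, 1), d_value(3, 4, 1)
  real :: d_x(3, 4, 1)

  call random_number(d_query)
  call random_number(d_key)
  call random_number(d_value)

  ! Self attention: the same tensor x was passed as query, key and value,
  ! so its total gradient is the sum of the three per-argument gradients.
  d_x = d_query + d_key + d_value

  ! Cross attention: d_query flows back to the querying sequence, while
  ! d_key and d_value flow back to the other sequence, so the three
  ! gradients are kept separate instead of being summed.
  print *, 'self-attention input gradient norm:', norm2(d_x)
end program self_vs_cross_gradients
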
@@ -67,6 +68,10 @@ module subroutine backward(self, input, gradient)
     end subroutine backward
 
     module subroutine forward(self, query, key, value)
+      !! General forward propagation for MultiHead Attention Mechanism
+      !! Might be used for both Self and Cross Attention
+      !! Self Attention: pass the same value thrice
+      !! Cross Attention: pass three values for your query, key and value
       class(multihead_attention_layer), intent(in out) :: self
       real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :)
     end subroutine forward
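
For the forward interface, the distinction between self and cross attention comes down to which arrays are passed for query, key and value. The snippet below is a call-pattern sketch only, not a compilable program: the instance name attn, the surrounding arrays, and the type-bound invocation style are assumptions, since the type definition is not part of this diff.

  ! Self attention: one sequence attends to itself, so the same
  ! activations are passed three times.
  call attn % forward(x, x, x)

  ! Cross attention: one sequence supplies the query while another
  ! supplies key and value.
  call attn % forward(decoder_x, encoder_out, encoder_out)
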
@@ -76,9 +81,7 @@ module subroutine init(self, input_shape)
       !!
       !! This is a deferred procedure from the `base_layer` abstract type.
       class(multihead_attention_layer), intent(in out) :: self
-        !! Dense layer instance
       integer, intent(in) :: input_shape(:)
-        !! Shape of the input layer
     end subroutine init
 
   end interface
@@ -115,7 +118,6 @@ module function multihead_attention_layer_cons(&
   end function multihead_attention_layer_cons
 
   module subroutine backward(self, input, gradient)
-    !! General backprop for MultiHead Attention mechanism
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :, :)
     real, intent(in) :: gradient(:, :, :)
@@ -214,7 +216,6 @@ module subroutine backward(self, input, gradient)
   end subroutine backward
 
   module subroutine forward(self, query, key, value)
-    !! General forward prop for MultiHead Attention Mechenism
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :)
 
@@ -261,17 +262,8 @@ module function split_heads(self, input) result(output)
     !! Split inputs into heads
     !!
     !! Example with two heads:
-    !! input (1, 3, 4):
-    !! [[[0.  , 0.3 , 0.6 , 0.9 ],
-    !!   [0.1 , 0.4 , 0.7 , 0.11],
-    !!   [0.2 , 0.5 , 0.8 , 0.12]]]
-    !! output (1, 2, 3, 2)
-    !! [[[[0.  , 0.3 ],
-    !     [0.1 , 0.4 ],
-    !     [0.2 , 0.5 ]],
-    !    [[0.6 , 0.9 ],
-    !     [0.7 , 0.11],
-    !     [0.8 , 0.12]]]]
+    !! input (3, 4, 1)
+    !! output (2, 3, 2, 1)
     class(multihead_attention_layer) :: self
     real :: input(:, :, :)
     real :: output(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
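
The shortened example now gives only the shapes. The standalone sketch below reproduces the same (3, 4, 1) to (2, 3, 2, 1) split with plain arrays so the index mapping stays visible; the loop-based copy and variable names are illustrative assumptions rather than the library's implementation, with each head taking a contiguous slice of the model dimension as in the removed worked example.

program split_heads_demo
  implicit none
  integer, parameter :: seq_len = 3, model_dim = 4, batch_size = 1
  integer, parameter :: n_heads = 2, head_size = model_dim / n_heads
  real :: input(seq_len, model_dim, batch_size)
  real :: output(n_heads, seq_len, head_size, batch_size)
  integer :: h, s, d, b

  call random_number(input)

  ! Head h takes the contiguous slice of the model dimension
  ! (h-1)*head_size+1 .. h*head_size at every sequence position.
  do b = 1, batch_size
    do d = 1, head_size
      do s = 1, seq_len
        do h = 1, n_heads
          output(h, s, d, b) = input(s, (h - 1) * head_size + d, b)
        end do
      end do
    end do
  end do

  print *, 'input shape: ', shape(input)    ! 3 4 1
  print *, 'output shape:', shape(output)   ! 2 3 2 1
end program split_heads_demo
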
