@@ -55,9 +55,10 @@ end function multihead_attention_layer_cons
   interface
 
     module subroutine backward(self, input, gradient)
-      !! Apply the backward gradient descent pass.
-      !! Only weight and bias gradients are updated in this subroutine,
-      !! while the weights and biases themselves are untouched.
+      !! General backprop for the MultiHead Attention mechanism
+      !! Can be used for both Self and Cross Attention
+      !! Self Attention: sum the output gradients
+      !! Cross Attention: use them separately
       class(multihead_attention_layer), intent(in out) :: self
         !! Dense layer instance
       real, intent(in) :: input(:, :, :)
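The new doc comment on `backward` compresses the self- vs cross-attention distinction into two lines; the standalone sketch below spells out what it means for the caller. The names `d_query`, `d_key` and `d_value` are hypothetical stand-ins for the three per-source gradients the backward pass produces, not components of the actual layer:

```fortran
program self_vs_cross_gradients
  implicit none

  ! Dummy per-source gradients with shape (sequence, model_dim, batch).
  real :: d_query(3, 4, 1), d_key(3, 4, 1), d_value(3, 4, 1)
  real :: d_input(3, 4, 1)

  d_query = 1.0
  d_key = 2.0
  d_value = 3.0

  ! Self Attention: query, key and value all came from the same tensor,
  ! so the gradient w.r.t. that tensor is the sum of all three branches.
  d_input = d_query + d_key + d_value
  print *, 'self-attention input gradient:', d_input(1, 1, 1)  ! prints 6.0

  ! Cross Attention: the query gradient flows back to one stream
  ! (e.g. the decoder) while the key/value gradients flow back to
  ! another (e.g. the encoder), so they are kept separate.
end program self_vs_cross_gradients
```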
@@ -67,6 +68,10 @@ module subroutine backward(self, input, gradient)
     end subroutine backward
 
     module subroutine forward(self, query, key, value)
+      !! General forward propagation for the MultiHead Attention mechanism
+      !! Can be used for both Self and Cross Attention
+      !! Self Attention: pass the same tensor as query, key and value
+      !! Cross Attention: pass separate query, key and value tensors
       class(multihead_attention_layer), intent(in out) :: self
       real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :)
     end subroutine forward
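A hedged usage sketch of the two calling conventions the new `forward` doc describes; `forward_stub` is a hypothetical stand-in with the same dummy-argument list as the module subroutine above, kept local so the example runs on its own:

```fortran
program forward_call_patterns
  implicit none

  real :: x(3, 4, 1)    ! one stream, shape (sequence, model_dim, batch)
  real :: enc(3, 4, 1)  ! a second stream, e.g. encoder output

  x = 0.5
  enc = 1.5

  ! Self Attention: the same tensor serves as query, key and value.
  call forward_stub(x, x, x)

  ! Cross Attention: the query comes from one stream, key and value
  ! from another.
  call forward_stub(x, enc, enc)

contains

  subroutine forward_stub(query, key, value)
    ! Stand-in for the layer method; the real one computes attention.
    real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :)
    print *, 'query/key/value shapes:', shape(query), shape(key), shape(value)
  end subroutine forward_stub

end program forward_call_patterns
```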
@@ -76,9 +81,7 @@ module subroutine init(self, input_shape)
       !!
       !! This is a deferred procedure from the `base_layer` abstract type.
       class(multihead_attention_layer), intent(in out) :: self
-        !! Dense layer instance
       integer, intent(in) :: input_shape(:)
-        !! Shape of the input layer
     end subroutine init
 
   end interface
@@ -115,7 +118,6 @@ module function multihead_attention_layer_cons(&
   end function multihead_attention_layer_cons
 
   module subroutine backward(self, input, gradient)
-    !! General backprop for MultiHead Attention mechanism
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :, :)
     real, intent(in) :: gradient(:, :, :)
@@ -214,7 +216,6 @@ module subroutine backward(self, input, gradient)
   end subroutine backward
 
   module subroutine forward(self, query, key, value)
-    !! General forward prop for MultiHead Attention Mechenism
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: query(:, :, :), key(:, :, :), value(:, :, :)
 
@@ -261,17 +262,8 @@ module function split_heads(self, input) result(output)
     !! Split inputs into heads
     !!
     !! Example with two heads:
-    !! input (1, 3, 4):
-    !! [[[0. , 0.3 , 0.6 , 0.9 ],
-    !!  [0.1 , 0.4 , 0.7 , 0.11],
-    !!  [0.2 , 0.5 , 0.8 , 0.12]]]
-    !! output (1, 2, 3, 2)
-    !! [[[[0. , 0.3 ],
-    !    [0.1 , 0.4 ],
-    !    [0.2 , 0.5 ]],
-    !   [[0.6 , 0.9 ],
-    !    [0.7 , 0.11],
-    !    [0.8 , 0.12]]]]
+    !! input (3, 4, 1)
+    !! output (2, 3, 2, 1)
     class(multihead_attention_layer) :: self
     real :: input(:, :, :)
     real :: output(self % n_heads, self % sequence_length, self % head_size, self % batch_size)
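The shortened example above gives only the shapes; the standalone sketch below reproduces them end to end. The contiguous-slice indexing is an assumption inferred from the old worked example this diff removes, not read off the implementation:

```fortran
program split_heads_demo
  implicit none

  integer, parameter :: n_heads = 2, seq_len = 3, model_dim = 4, batch = 1
  integer, parameter :: head_size = model_dim / n_heads
  real :: input(seq_len, model_dim, batch)
  real :: output(n_heads, seq_len, head_size, batch)
  integer :: h, s, d, b

  call random_number(input)

  ! Head h takes the contiguous feature slice
  ! (h-1)*head_size+1 .. h*head_size of the model dimension.
  do b = 1, batch
    do d = 1, head_size
      do s = 1, seq_len
        do h = 1, n_heads
          output(h, s, d, b) = input(s, (h - 1) * head_size + d, b)
        end do
      end do
    end do
  end do

  print *, 'input shape: ', shape(input)   ! 3 4 1
  print *, 'output shape:', shape(output)  ! 2 3 2 1
end program split_heads_demo
```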