@@ -141,11 +141,19 @@ subroutine test_multihead_attention_forward(attention, ok)
141141 real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size)
142142 real :: output_flat(12 )
143143 integer :: output_shape(3 )
144+ integer :: attn_weights_shape(4 )
145+ real :: attn_weights_flat(18 )
144146 integer :: expected_shape(3 ) = [3 , 4 , 1 ]
145147 real :: expected_output_flat(12 ) = [&
146148 0.982241452 , 1.00407875 , 1.00444126 , 0.982241452 , 1.00407875 , 1.00444126 ,&
147149 0.982241452 , 1.00407875 , 1.00444126 , 0.982241452 , 1.00407875 , 1.00444126 &
148150 ]
151+ integer :: expected_attn_weights_shape(4 ) = [2 , 3 , 3 , 1 ]
152+ real :: expected_attn_weights_flat(18 ) = [&
153+ 7.89450705E-02 , 7.89450705E-02 , 2.28110179E-02 , 2.28110179E-02 , 2.18846574E-02 , 2.18846574E-02 ,&
154+ 0.447508544 , 0.447508544 , 0.464612424 , 0.464612424 , 0.464721352 , 0.464721352 ,&
155+ 0.473546445 , 0.473546445 , 0.512576580 , 0.512576580 , 0.513393998 , 0.513393998 &
156+ ]
149157
150158 call attention % forward(input, input, input)
151159
@@ -159,6 +167,17 @@ subroutine test_multihead_attention_forward(attention, ok)
159167 ok = .false.
160168 write (stderr, ' (a)' ) ' forward returned incorrect values.. failed'
161169 end if
170+
171+ attn_weights_shape = shape (attention % attention_matrix)
172+ if (.not. all (attn_weights_shape.eq. expected_attn_weights_shape)) then
173+ ok = .false.
174+ write (stderr, ' (a)' ) ' forward returned incorrect attention weights shape.. failed'
175+ end if
176+ attn_weights_flat = reshape (attention % attention_matrix, shape (attn_weights_flat))
177+ if (.not. all (attn_weights_flat.eq. expected_attn_weights_flat)) then
178+ ok = .false.
179+ write (stderr, ' (a)' ) ' forward returned incorrect attention weights values.. failed'
180+ end if
162181 end subroutine test_multihead_attention_forward
163182
164183 subroutine test_multihead_attention_forward_reallife_shape (ok )
0 commit comments