Skip to content

Commit e7911e9

Browse files
committed
multihead_attention: tests, add checks for attention weights
1 parent 4548316 commit e7911e9

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

test/test_multihead_attention_layer.f90

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,19 @@ subroutine test_multihead_attention_forward(attention, ok)
141141
real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size)
142142
real :: output_flat(12)
143143
integer :: output_shape(3)
144+
integer :: attn_weights_shape(4)
145+
real :: attn_weights_flat(18)
144146
integer :: expected_shape(3) = [3, 4, 1]
145147
real :: expected_output_flat(12) = [&
146148
0.982241452, 1.00407875, 1.00444126, 0.982241452, 1.00407875, 1.00444126,&
147149
0.982241452, 1.00407875, 1.00444126, 0.982241452, 1.00407875, 1.00444126&
148150
]
151+
integer :: expected_attn_weights_shape(4) = [2, 3, 3, 1]
152+
real :: expected_attn_weights_flat(18) = [&
153+
7.89450705E-02, 7.89450705E-02, 2.28110179E-02, 2.28110179E-02, 2.18846574E-02, 2.18846574E-02,&
154+
0.447508544, 0.447508544, 0.464612424, 0.464612424, 0.464721352, 0.464721352,&
155+
0.473546445, 0.473546445, 0.512576580, 0.512576580, 0.513393998, 0.513393998&
156+
]
149157

150158
call attention % forward(input, input, input)
151159

@@ -159,6 +167,17 @@ subroutine test_multihead_attention_forward(attention, ok)
159167
ok = .false.
160168
write(stderr, '(a)') 'forward returned incorrect values.. failed'
161169
end if
170+
171+
attn_weights_shape = shape(attention % attention_matrix)
172+
if (.not. all(attn_weights_shape.eq.expected_attn_weights_shape)) then
173+
ok = .false.
174+
write(stderr, '(a)') 'forward returned incorrect attention weights shape.. failed'
175+
end if
176+
attn_weights_flat = reshape(attention % attention_matrix, shape(attn_weights_flat))
177+
if (.not. all(attn_weights_flat.eq.expected_attn_weights_flat)) then
178+
ok = .false.
179+
write(stderr, '(a)') 'forward returned incorrect attention weights values.. failed'
180+
end if
162181
end subroutine test_multihead_attention_forward
163182

164183
subroutine test_multihead_attention_forward_reallife_shape(ok)

0 commit comments

Comments
 (0)