@@ -6,12 +6,12 @@ program test_linear2d_layer
66 logical :: ok = .true.
77 real :: sample_input(3 , 4 ) = reshape (&
88 [0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.2 , 0.2 , 0.2 , 0.2 , 0.2 , 0.2 ],&
9- [3 , 4 ]) ! first batch are 0.1, second 0.2
9+ [3 , 4 ])
1010 real :: sample_gradient(3 , 1 ) = reshape ([2 ., 2 ., 3 .], [3 , 1 ])
1111 type (linear2d_layer) :: linear
1212
13- linear = linear2d_layer(sequence_length = 3 , in_features = 4 , out_features= 1 )
14- call linear % init([4 ])
13+ linear = linear2d_layer(out_features= 1 )
14+ call linear % init([3 , 4 ])
1515
1616 call test_linear2d_layer_forward(linear, ok, sample_input)
1717 call test_linear2d_layer_backward(linear, ok, sample_input, sample_gradient)
@@ -131,8 +131,8 @@ subroutine test_linear2d_layer_gradient_updates(ok)
131131
132132 integer :: i
133133
134- linear = linear2d_layer(sequence_length = 3 , in_features = 4 , out_features= 2 , batch_size = 1 )
135- call linear % init([4 ])
134+ linear = linear2d_layer(out_features= 2 )
135+ call linear % init([3 , 4 ])
136136 call linear % forward(input)
137137 call linear % backward(input, gradient)
138138
0 commit comments