44 use nf_conv2d_layer, only: conv2d_layer
55 use nf_dense_layer, only: dense_layer
66 use nf_flatten_layer, only: flatten_layer
7+ use nf_flatten2d_layer, only: flatten2d_layer
78 use nf_input1d_layer, only: input1d_layer
89 use nf_input2d_layer, only: input2d_layer
910 use nf_input3d_layer, only: input3d_layer
@@ -46,8 +47,16 @@ pure module subroutine backward_1d(self, previous, gradient)
4647 call this_layer % backward(prev_layer % output, gradient)
4748 type is (maxpool2d_layer)
4849 call this_layer % backward(prev_layer % output, gradient)
49- ! type is(linear2d_layer)
50- ! call this_layer % backward(prev_layer % output, gradient)
50+ end select
51+
52+ type is (flatten2d_layer)
53+
54+ ! Upstream layers permitted: linear2d_layer
55+ select type (prev_layer = > previous % p)
56+ type is (linear2d_layer)
57+ call this_layer % backward(prev_layer % output, gradient)
58+ type is (input2d_layer)
59+ call this_layer % backward(prev_layer % output, gradient)
5160 end select
5261
5362 end select
@@ -61,8 +70,6 @@ pure module subroutine backward_2d(self, previous, gradient)
6170 class(layer), intent (in ) :: previous
6271 real , intent (in ) :: gradient(:,:)
6372
64- ! Backward pass from a 2-d layer downstream currently implemented
65- ! only for input2d and linear2d layers
6673 select type (this_layer = > self % p)
6774
6875 type is (linear2d_layer)
@@ -193,8 +200,14 @@ pure module subroutine forward(self, input)
193200 call this_layer % forward(prev_layer % output)
194201 type is (reshape3d_layer)
195202 call this_layer % forward(prev_layer % output)
196- ! type is(linear2d_layer)
197- ! call this_layer % forward(prev_layer % output)
203+ end select
204+
205+ type is (flatten2d_layer)
206+ select type (prev_layer = > input % p)
207+ type is (linear2d_layer)
208+ call this_layer % forward(prev_layer % output)
209+ type is (input2d_layer)
210+ call this_layer % forward(prev_layer % output)
198211 end select
199212
200213 type is (reshape3d_layer)
@@ -237,6 +250,8 @@ pure module subroutine get_output_1d(self, output)
237250 allocate (output, source= this_layer % output)
238251 type is (flatten_layer)
239252 allocate (output, source= this_layer % output)
253+ type is (flatten2d_layer)
254+ allocate (output, source= this_layer % output)
240255 class default
241256 error stop ' 1-d output can only be read from an input1d, dense, or flatten layer.'
242257
@@ -308,9 +323,11 @@ impure elemental module subroutine init(self, input)
308323 self % layer_shape = shape (this_layer % output)
309324 type is (flatten_layer)
310325 self % layer_shape = shape (this_layer % output)
326+ type is (flatten2d_layer)
327+ self % layer_shape = shape (this_layer % output)
311328 end select
312329
313- self % input_layer_shape = input % layer_shape
330+ self % input_layer_shape = input % layer_shape
314331 self % initialized = .true.
315332
316333 end subroutine init
@@ -351,6 +368,8 @@ elemental module function get_num_params(self) result(num_params)
351368 num_params = 0
352369 type is (flatten_layer)
353370 num_params = 0
371+ type is (flatten2d_layer)
372+ num_params = 0
354373 type is (reshape3d_layer)
355374 num_params = 0
356375 type is (linear2d_layer)
@@ -380,6 +399,8 @@ module function get_params(self) result(params)
380399 ! No parameters to get.
381400 type is (flatten_layer)
382401 ! No parameters to get.
402+ type is (flatten2d_layer)
403+ ! No parameters to get.
383404 type is (reshape3d_layer)
384405 ! No parameters to get.
385406 type is (linear2d_layer)
@@ -408,6 +429,8 @@ module function get_gradients(self) result(gradients)
408429 type is (maxpool2d_layer)
409430 ! No gradients to get.
410431 type is (flatten_layer)
432+ ! No parameters to get.
433+ type is (flatten2d_layer)
411434 ! No gradients to get.
412435 type is (reshape3d_layer)
413436 ! No gradients to get.
@@ -473,6 +496,11 @@ module subroutine set_params(self, params)
473496 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
474497 // ' on a zero-parameter layer; nothing to do.'
475498
499+ type is (flatten2d_layer)
500+ ! No parameters to set.
501+ write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
502+ // ' on a zero-parameter layer; nothing to do.'
503+
476504 type is (reshape3d_layer)
477505 ! No parameters to set.
478506 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments