Skip to content

Commit 422ae39

Browse files
committed
Generic flatten() with 2-d and 3-d inputs
1 parent a28a9be commit 422ae39

File tree

5 files changed

+119
-22
lines changed

5 files changed

+119
-22
lines changed

src/nf/nf_flatten_layer.f90

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,20 @@ module nf_flatten_layer
1818
integer, allocatable :: input_shape(:)
1919
integer :: output_size
2020

21-
real, allocatable :: gradient(:,:,:)
21+
real, allocatable :: gradient_2d(:,:)
22+
real, allocatable :: gradient_3d(:,:,:)
2223
real, allocatable :: output(:)
2324

2425
contains
2526

26-
procedure :: backward
27-
procedure :: forward
27+
procedure :: backward_2d
28+
procedure :: backward_3d
29+
generic :: backward => backward_2d, backward_3d
30+
31+
procedure :: forward_2d
32+
procedure :: forward_3d
33+
generic :: forward => forward_2d, forward_3d
34+
2835
procedure :: init
2936

3037
end type flatten_layer
@@ -39,26 +46,47 @@ end function flatten_layer_cons
3946

4047
interface
4148

42-
pure module subroutine backward(self, input, gradient)
43-
!! Apply the backward pass to the flatten layer.
49+
pure module subroutine backward_2d(self, input, gradient)
50+
!! Apply the backward pass to the flatten layer for 2D input.
51+
!! This is a reshape operation from 1-d gradient to 2-d input.
52+
class(flatten_layer), intent(in out) :: self
53+
!! Flatten layer instance
54+
real, intent(in) :: input(:,:)
55+
!! Input from the previous layer
56+
real, intent(in) :: gradient(:)
57+
!! Gradient from the next layer
58+
end subroutine backward_2d
59+
60+
pure module subroutine backward_3d(self, input, gradient)
61+
!! Apply the backward pass to the flatten layer for 3D input.
4462
!! This is a reshape operation from 1-d gradient to 3-d input.
4563
class(flatten_layer), intent(in out) :: self
4664
!! Flatten layer instance
4765
real, intent(in) :: input(:,:,:)
4866
!! Input from the previous layer
4967
real, intent(in) :: gradient(:)
5068
!! Gradient from the next layer
51-
end subroutine backward
69+
end subroutine backward_3d
70+
71+
pure module subroutine forward_2d(self, input)
72+
!! Propagate forward the layer for 2D input.
73+
!! Calling this subroutine updates the values of a few data components
74+
!! of `flatten_layer` that are needed for the backward pass.
75+
class(flatten_layer), intent(in out) :: self
76+
!! Dense layer instance
77+
real, intent(in) :: input(:,:)
78+
!! Input from the previous layer
79+
end subroutine forward_2d
5280

53-
pure module subroutine forward(self, input)
54-
!! Propagate forward the layer.
81+
pure module subroutine forward_3d(self, input)
82+
!! Propagate forward the layer for 3D input.
5583
!! Calling this subroutine updates the values of a few data components
5684
!! of `flatten_layer` that are needed for the backward pass.
5785
class(flatten_layer), intent(in out) :: self
5886
!! Dense layer instance
5987
real, intent(in) :: input(:,:,:)
6088
!! Input from the previous layer
61-
end subroutine forward
89+
end subroutine forward_3d
6290

6391
module subroutine init(self, input_shape)
6492
!! Initialize the layer data structures.

src/nf/nf_flatten_layer_submodule.f90

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,34 @@ elemental module function flatten_layer_cons() result(res)
1515
end function flatten_layer_cons
1616

1717

18-
pure module subroutine backward(self, input, gradient)
18+
pure module subroutine backward_2d(self, input, gradient)
19+
class(flatten_layer), intent(in out) :: self
20+
real, intent(in) :: input(:,:)
21+
real, intent(in) :: gradient(:)
22+
self % gradient_2d = reshape(gradient, shape(input))
23+
end subroutine backward_2d
24+
25+
26+
pure module subroutine backward_3d(self, input, gradient)
1927
class(flatten_layer), intent(in out) :: self
2028
real, intent(in) :: input(:,:,:)
2129
real, intent(in) :: gradient(:)
22-
self % gradient = reshape(gradient, shape(input))
23-
end subroutine backward
30+
self % gradient_3d = reshape(gradient, shape(input))
31+
end subroutine backward_3d
32+
33+
34+
pure module subroutine forward_2d(self, input)
35+
class(flatten_layer), intent(in out) :: self
36+
real, intent(in) :: input(:,:)
37+
self % output = pack(input, .true.)
38+
end subroutine forward_2d
2439

2540

26-
pure module subroutine forward(self, input)
41+
pure module subroutine forward_3d(self, input)
2742
class(flatten_layer), intent(in out) :: self
2843
real, intent(in) :: input(:,:,:)
2944
self % output = pack(input, .true.)
30-
end subroutine forward
45+
end subroutine forward_3d
3146

3247

3348
module subroutine init(self, input_shape)
@@ -37,8 +52,13 @@ module subroutine init(self, input_shape)
3752
self % input_shape = input_shape
3853
self % output_size = product(input_shape)
3954

40-
allocate(self % gradient(input_shape(1), input_shape(2), input_shape(3)))
41-
self % gradient = 0
55+
if (size(input_shape) == 2) then
56+
allocate(self % gradient_2d(input_shape(1), input_shape(2)))
57+
self % gradient_2d = 0
58+
else if (size(input_shape) == 3) then
59+
allocate(self % gradient_3d(input_shape(1), input_shape(2), input_shape(3)))
60+
self % gradient_3d = 0
61+
end if
4262

4363
allocate(self % output(self % output_size))
4464
self % output = 0

src/nf/nf_layer_submodule.f90

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ pure module subroutine backward_1d(self, previous, gradient)
3737

3838
type is(flatten_layer)
3939

40-
! Upstream layers permitted: input3d, conv2d, maxpool2d
40+
! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d
4141
select type(prev_layer => previous % p)
42+
type is(input2d_layer)
43+
call this_layer % backward(prev_layer % output, gradient)
4244
type is(input3d_layer)
4345
call this_layer % backward(prev_layer % output, gradient)
4446
type is(conv2d_layer)
@@ -168,8 +170,10 @@ pure module subroutine forward(self, input)
168170

169171
type is(flatten_layer)
170172

171-
! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d
173+
! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d, reshape3d
172174
select type(prev_layer => input % p)
175+
type is(input2d_layer)
176+
call this_layer % forward(prev_layer % output)
173177
type is(input3d_layer)
174178
call this_layer % forward(prev_layer % output)
175179
type is(conv2d_layer)

src/nf/nf_network_submodule.f90

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,20 @@ module subroutine backward(self, output, loss)
135135
select type(next_layer => self % layers(n + 1) % p)
136136
type is(dense_layer)
137137
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
138+
138139
type is(conv2d_layer)
139140
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
141+
140142
type is(flatten_layer)
141-
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
143+
if (size(self % layers(n) % layer_shape) == 2) then
144+
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
145+
else
146+
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
147+
end if
148+
142149
type is(maxpool2d_layer)
143150
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
151+
144152
type is(reshape3d_layer)
145153
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
146154
end select

test/test_flatten_layer.f90

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@ program test_flatten_layer
33
use iso_fortran_env, only: stderr => error_unit
44
use nf, only: dense, flatten, input, layer, network
55
use nf_flatten_layer, only: flatten_layer
6+
use nf_input2d_layer, only: input2d_layer
67
use nf_input3d_layer, only: input3d_layer
78

89
implicit none
910

1011
type(layer) :: test_layer, input_layer
1112
type(network) :: net
12-
real, allocatable :: gradient(:,:,:)
13+
real, allocatable :: gradient_3d(:,:,:), gradient_2d(:,:)
1314
real, allocatable :: output(:)
1415
logical :: ok = .true.
1516

17+
! Test 3D input
1618
test_layer = flatten()
1719

1820
if (.not. test_layer % name == 'flatten') then
@@ -59,14 +61,49 @@ program test_flatten_layer
5961
call test_layer % backward(input_layer, real([1, 2, 3, 4]))
6062

6163
select type(this_layer => test_layer % p); type is(flatten_layer)
62-
gradient = this_layer % gradient
64+
gradient_3d = this_layer % gradient_3d
6365
end select
6466

65-
if (.not. all(gradient == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then
67+
if (.not. all(gradient_3d == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then
6668
ok = .false.
6769
write(stderr, '(a)') 'flatten layer correctly propagates backward.. failed'
6870
end if
6971

72+
! Test 2D input
73+
test_layer = flatten()
74+
input_layer = input(2, 3)
75+
call test_layer % init(input_layer)
76+
77+
if (.not. all(test_layer % layer_shape == [6])) then
78+
ok = .false.
79+
write(stderr, '(a)') 'flatten layer has an incorrect output shape for 2D input.. failed'
80+
end if
81+
82+
! Test forward pass - reshaping from 2-d to 1-d
83+
select type(this_layer => input_layer % p); type is(input2d_layer)
84+
call this_layer % set(reshape(real([1, 2, 3, 4, 5, 6]), [2, 3]))
85+
end select
86+
87+
call test_layer % forward(input_layer)
88+
call test_layer % get_output(output)
89+
90+
if (.not. all(output == [1, 2, 3, 4, 5, 6])) then
91+
ok = .false.
92+
write(stderr, '(a)') 'flatten layer correctly propagates forward for 2D input.. failed'
93+
end if
94+
95+
! Test backward pass - reshaping from 1-d to 2-d
96+
call test_layer % backward(input_layer, real([1, 2, 3, 4, 5, 6]))
97+
98+
select type(this_layer => test_layer % p); type is(flatten_layer)
99+
gradient_2d = this_layer % gradient_2d
100+
end select
101+
102+
if (.not. all(gradient_2d == reshape(real([1, 2, 3, 4, 5, 6]), [2, 3]))) then
103+
ok = .false.
104+
write(stderr, '(a)') 'flatten layer correctly propagates backward for 2D input.. failed'
105+
end if
106+
70107
net = network([ &
71108
input(1, 28, 28), &
72109
flatten(), &

0 commit comments

Comments
 (0)