Skip to content

Commit a97f141

Browse files
committed
linear2d_layer: remove redundant constructor args
1 parent 539fde8 commit a97f141

File tree

6 files changed

+22
-22
lines changed

6 files changed

+22
-22
lines changed

example/linear2d.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ program linear2d_example
88
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], &
99
[3, 4])
1010
real :: y(3) = [0.12, 0.1, 0.3]
11-
integer, parameter :: num_iterations = 500
11+
integer, parameter :: num_iterations = 5
1212
integer :: n
1313

1414
net = network([ &
1515
input(3, 4), &
16-
linear2d(3, 4, 1), &
16+
linear2d(3, 1), &
1717
flatten() &
1818
])
1919

src/nf/nf_layer_constructors.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,8 @@ module function reshape(output_shape) result(res)
185185
!! Resulting layer instance
186186
end function reshape
187187

188-
module function linear2d(sequence_length, in_features, out_features) result(res)
189-
integer, intent(in) :: sequence_length, in_features, out_features
188+
module function linear2d(sequence_length, out_features) result(res)
189+
integer, intent(in) :: sequence_length, out_features
190190
type(layer) :: res
191191
end function linear2d
192192

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,13 @@ module function reshape(output_shape) result(res)
150150

151151
end function reshape
152152

153-
module function linear2d(sequence_length, in_features, out_features) result(res)
154-
integer, intent(in) :: sequence_length, in_features, out_features
153+
module function linear2d(sequence_length, out_features) result(res)
154+
integer, intent(in) :: sequence_length, out_features
155155
type(layer) :: res
156156

157157
res % name = 'linear2d'
158158
res % layer_shape = [sequence_length, out_features]
159-
allocate(res % p, source=linear2d_layer(sequence_length, in_features, out_features))
159+
allocate(res % p, source=linear2d_layer(out_features))
160160
end function linear2d
161161

162162
end submodule nf_layer_constructors_submodule

src/nf/nf_linear2d_layer.f90

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@ module nf_linear2d_layer
3131
end type linear2d_layer
3232

3333
interface linear2d_layer
34-
module function linear2d_layer_cons(&
35-
sequence_length, in_features, out_features&
36-
) result(res)
37-
integer, intent(in) :: sequence_length, in_features, out_features
34+
module function linear2d_layer_cons(out_features) result(res)
35+
integer, intent(in) :: out_features
3836
type(linear2d_layer) :: res
3937
end function linear2d_layer_cons
4038
end interface linear2d_layer

src/nf/nf_linear2d_layer_submodule.f90

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,23 @@
22
use nf_base_layer, only: base_layer
33
implicit none
44
contains
5-
module function linear2d_layer_cons(&
6-
sequence_length, in_features, out_features&
7-
) result(res)
8-
integer, intent(in) :: sequence_length, in_features, out_features
5+
module function linear2d_layer_cons(out_features) result(res)
6+
integer, intent(in) :: out_features
97
type(linear2d_layer) :: res
108

11-
res % in_features = in_features
129
res % out_features = out_features
13-
res % sequence_length = sequence_length
1410
end function linear2d_layer_cons
1511

1612
module subroutine init(self, input_shape)
1713
class(linear2d_layer), intent(in out) :: self
1814
integer, intent(in) :: input_shape(:)
1915

16+
if (size(input_shape) /= 2) then
17+
error stop "Linear2D Layer accepts 2D input"
18+
end if
19+
self % sequence_length = input_shape(1)
20+
self % in_features = input_shape(2)
21+
2022
allocate(self % output(self % sequence_length, self % out_features))
2123
allocate(self % gradient(self % sequence_length, self % in_features))
2224

test/test_linear2d_layer.f90

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ program test_linear2d_layer
66
logical :: ok = .true.
77
real :: sample_input(3, 4) = reshape(&
88
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],&
9-
[3, 4]) ! first batch are 0.1, second 0.2
9+
[3, 4])
1010
real :: sample_gradient(3, 1) = reshape([2., 2., 3.], [3, 1])
1111
type(linear2d_layer) :: linear
1212

13-
linear = linear2d_layer(sequence_length=3, in_features=4, out_features=1)
14-
call linear % init([4])
13+
linear = linear2d_layer(out_features=1)
14+
call linear % init([3, 4])
1515

1616
call test_linear2d_layer_forward(linear, ok, sample_input)
1717
call test_linear2d_layer_backward(linear, ok, sample_input, sample_gradient)
@@ -131,8 +131,8 @@ subroutine test_linear2d_layer_gradient_updates(ok)
131131

132132
integer :: i
133133

134-
linear = linear2d_layer(sequence_length=3, in_features=4, out_features=2, batch_size=1)
135-
call linear % init([4])
134+
linear = linear2d_layer(out_features=2)
135+
call linear % init([3, 4])
136136
call linear % forward(input)
137137
call linear % backward(input, gradient)
138138

0 commit comments

Comments
 (0)