Skip to content

Commit 2166509

Browse files
committed
bug fixes; integrating reshape_generalized in environment
1 parent 76a8d1c commit 2166509

File tree

5 files changed

+22
-7
lines changed

5 files changed

+22
-7
lines changed

example/cnn_mnist_1d.f90

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program cnn_mnist
22

33
use nf, only: network, sgd, &
4-
input, conv2d, maxpool2d, flatten, dense, reshape, locally_connected_1d, &
4+
input, conv2d, maxpool2d, flatten, dense, reshape, reshape_generalized, locally_connected_1d, &
55
load_mnist, label_digits, softmax, relu
66

77
implicit none
@@ -20,8 +20,11 @@ program cnn_mnist
2020

2121
net = network([ &
2222
input(784), &
23-
reshape([1,28,28]), &
24-
locally_connected_1d(filters=8, kernel_size=2, activation=relu()), &
23+
reshape([1, 28, 28]), &
24+
conv2d(filters=8, kernel_size=3, activation=relu()), &
25+
maxpool2d(pool_size=2), &
26+
conv2d(filters=16, kernel_size=3, activation=relu()), &
27+
maxpool2d(pool_size=2), &
2528
dense(10, activation=softmax()) &
2629
])
2730

src/nf/nf_layer_constructors.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ module function reshape(output_shape) result(res)
213213
end function reshape
214214

215215
module function reshape_generalized(output_shape) result(res)
216-
integer, intent(in) :: output_shape
216+
integer, intent(in) :: output_shape(:)
217217
type(layer) :: res
218218

219219
end function reshape_generalized

src/nf/nf_layer_submodule.f90

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
use nf_locally_connected_1d_layer, only: locally_connected_1d_layer
1111
use nf_maxpool2d_layer, only: maxpool2d_layer
1212
use nf_reshape_layer, only: reshape3d_layer
13+
use nf_reshape_layer_generalized, only: reshape_generalized_layer
1314
use nf_optimizers, only: optimizer_base_type
1415

1516
contains
@@ -325,6 +326,8 @@ elemental module function get_num_params(self) result(num_params)
325326
num_params = 0
326327
type is (reshape3d_layer)
327328
num_params = 0
329+
type is (reshape_generalized_layer)
330+
num_params = 0
328331
class default
329332
error stop 'Unknown layer type.'
330333
end select
@@ -352,6 +355,8 @@ module function get_params(self) result(params)
352355
! No parameters to get.
353356
type is (reshape3d_layer)
354357
! No parameters to get.
358+
type is (reshape_generalized_layer)
359+
! No parameters to get.
355360
class default
356361
error stop 'Unknown layer type.'
357362
end select
@@ -379,6 +384,8 @@ module function get_gradients(self) result(gradients)
379384
! No gradients to get.
380385
type is (reshape3d_layer)
381386
! No gradients to get.
387+
type is (reshape_generalized_layer)
388+
! No gradients to get.
382389
class default
383390
error stop 'Unknown layer type.'
384391
end select
@@ -440,7 +447,12 @@ module subroutine set_params(self, params)
440447
! No parameters to set.
441448
write(stderr, '(a)') 'Warning: calling set_params() ' &
442449
// 'on a zero-parameter layer; nothing to do.'
443-
450+
451+
type is (reshape_generalized_layer)
452+
! No parameters to set.
453+
write(stderr, '(a)') 'Warning: calling set_params() ' &
454+
// 'on a zero-parameter layer; nothing to do.'
455+
444456
class default
445457
error stop 'Unknown layer type.'
446458
end select

src/nf/nf_network_submodule.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
use nf_maxpool2d_layer, only: maxpool2d_layer
1111
use nf_reshape_layer, only: reshape3d_layer
1212
use nf_layer, only: layer
13-
use nf_layer_constructors, only: conv2d, dense, flatten, input, locally_connected_1d, maxpool2d, reshape
13+
use nf_layer_constructors, only: conv2d, dense, flatten, input, locally_connected_1d, maxpool2d, reshape, reshape_generalized
1414
use nf_loss, only: quadratic
1515
use nf_optimizers, only: optimizer_base_type, sgd
1616
use nf_parallel, only: tile_indices

src/nf/nf_reshape_generalized_submodule.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pure module function reshape_layer_cons(output_shape) result(res)
1111
type(reshape_generalized_layer) :: res
1212

1313
! Check if output_shape is scalar (size 1)
14-
if (size(output_shape) == 1) then
14+
if (size(output_shape) == 0) then
1515
allocate(res % output_shape(1))
1616
res % output_shape = output_shape
1717
else

0 commit comments

Comments
 (0)