@@ -61,24 +61,32 @@ pure module function get_num_params(self) result(num_params)
6161 end function get_num_params
6262
6363
64- pure module function get_params(self) result(params)
65- class(dense_layer), intent (in ) :: self
64+ module function get_params (self ) result(params)
65+ class(dense_layer), intent (in ), target :: self
6666 real , allocatable :: params(:)
6767
68+ real , pointer :: w_(:) = > null ()
69+
70+ w_(1 :size (self % weights)) = > self % weights
71+
6872 params = [ &
69- pack (self % weights, .true. ) , &
73+ w_ , &
7074 self % biases &
7175 ]
7276
7377 end function get_params
7478
7579
76- pure module function get_gradients(self) result(gradients)
77- class(dense_layer), intent (in ) :: self
80+ module function get_gradients (self ) result(gradients)
81+ class(dense_layer), intent (in ), target :: self
7882 real , allocatable :: gradients(:)
7983
84+ real , pointer :: dw_(:) = > null ()
85+
86+ dw_(1 :size (self % dw)) = > self % dw
87+
8088 gradients = [ &
81- pack (self % dw, .true. ) , &
89+ dw_ , &
8290 self % db &
8391 ]
8492
@@ -87,24 +95,23 @@ end function get_gradients
8795
8896 module subroutine set_params (self , params )
8997 class(dense_layer), intent (in out ) :: self
90- real , intent (in ) :: params(:)
98+ real , intent (in ), target :: params(:)
99+
100+ real , pointer :: p_(:,:) = > null ()
91101
92102 ! check if the number of parameters is correct
93103 if (size (params) /= self % get_num_params()) then
94104 error stop ' Error: number of parameters does not match'
95105 end if
96106
97- ! reshape the weights
98- self % weights = reshape ( &
99- params(:self % input_size * self % output_size), &
100- [self % input_size, self % output_size] &
101- )
102-
103- ! reshape the biases
104- self % biases = reshape ( &
105- params(self % input_size * self % output_size + 1 :), &
106- [self % output_size] &
107- )
107+ associate(n = > self % input_size * self % output_size)
108+ ! reshape the weights
109+ p_(1 :self % input_size, 1 :self % output_size) = > params(1 : n)
110+ self % weights = p_
111+
112+ ! reshape the biases
113+ self % biases = params(n + 1 : n + self % output_size)
114+ end associate
108115
109116 end subroutine set_params
110117
0 commit comments