@@ -11,7 +11,7 @@
   use nf_keras, only: get_keras_h5_layers, keras_layer
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
-  use nf_loss, only: quadratic_derivative
+  use nf_loss, only: quadratic
   use nf_optimizers, only: optimizer_base_type, sgd
   use nf_parallel, only: tile_indices
   use nf_activation, only: activation_function, &
@@ -280,11 +280,27 @@ pure function get_activation_by_name(activation_name) result(res)
 
   end function get_activation_by_name
 
-  pure module subroutine backward(self, output)
+  pure module subroutine backward(self, output, loss)
     class(network), intent(in out) :: self
     real, intent(in) :: output(:)
+    class(loss_type), intent(in), optional :: loss
     integer :: n, num_layers
 
+    ! Passing the loss instance is optional. If not provided, and if the
+    ! loss instance has not already been set, we default to quadratic.
+    ! Instantiating and initializing the loss instance is normally done at
+    ! the beginning of the network % train() method. However, if the user
+    ! wants to call network % backward() directly, for example from their
+    ! own custom mini-batching routine, we initialize the loss instance here
+    ! as well. If it's initialized already, this step is a cheap no-op.
+    if (.not. allocated(self % loss)) then
+      if (present(loss)) then
+        self % loss = loss
+      else
+        self % loss = quadratic()
+      end if
+    end if
+
     num_layers = size(self % layers)
 
     ! Iterate backward over layers, from the output layer
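The backward pass now dispatches through a polymorphic loss component instead of calling `quadratic_derivative` directly. The `allocated(self % loss)` check above implies the `network` type gains a `class(loss_type), allocatable :: loss` component (polymorphic assignment requires an allocatable left-hand side), and the changed import implies `nf_loss` now exposes an abstract `loss_type` with a deferred `derivative` binding that `quadratic` extends. A minimal sketch of that arrangement, assuming those names; only `derivative` is exercised in this diff, so the real module may define more bindings (e.g. an `eval` for the loss value itself):

```fortran
module nf_loss_sketch
  implicit none
  private
  public :: loss_type, quadratic

  abstract interface
    pure function loss_derivative_interface(true, predicted) result(res)
      real, intent(in) :: true(:)
      real, intent(in) :: predicted(:)
      real :: res(size(true))
    end function loss_derivative_interface
  end interface

  ! Abstract base: concrete losses implement the deferred binding.
  type, abstract :: loss_type
  contains
    procedure(loss_derivative_interface), nopass, deferred :: derivative
  end type loss_type

  ! quadratic() works as a structure constructor because the type
  ! carries no components that need initialization.
  type, extends(loss_type) :: quadratic
  contains
    procedure, nopass :: derivative => quadratic_derivative
  end type quadratic

contains

  ! Derivative of 1/2 * sum((predicted - true)**2) w.r.t. predicted
  pure function quadratic_derivative(true, predicted) result(res)
    real, intent(in) :: true(:)
    real, intent(in) :: predicted(:)
    real :: res(size(true))
    res = predicted - true
  end function quadratic_derivative

end module nf_loss_sketch
```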
@@ -297,7 +313,7 @@ pure module subroutine backward(self, output)
           type is (dense_layer)
             call self % layers(n) % backward( &
               self % layers(n - 1), &
-              quadratic_derivative(output, this_layer % output) &
+              self % loss % derivative(output, this_layer % output) &
             )
         end select
       else
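As the comment in `backward` notes, users who run their own mini-batching can now drive the backward pass directly and choose the loss at that call site. A usage sketch, assuming the user-facing `nf` module and the `forward`/`update` methods as shown in the library's README; the toy data and shapes are hypothetical:

```fortran
program custom_batching_sketch
  use nf, only: dense, input, network, sgd
  use nf_loss, only: quadratic
  implicit none
  type(network) :: net
  real :: x(3), y(2)
  integer :: i

  net = network([input(3), dense(2)])

  do i = 1, 100
    call random_number(x)
    y = [sum(x), product(x)]  ! toy targets
    call net % forward(x)
    ! Passing loss is only needed on the first call; once allocated,
    ! the network's loss component is reused (a cheap no-op thereafter).
    call net % backward(y, loss=quadratic())
    call net % update(optimizer=sgd(learning_rate=0.1))
  end do
end program custom_batching_sketch
```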
@@ -542,13 +558,14 @@ end subroutine set_params
 
 
   module subroutine train(self, input_data, output_data, batch_size, &
-    epochs, optimizer)
+    epochs, optimizer, loss)
     class(network), intent(in out) :: self
     real, intent(in) :: input_data(:,:)
     real, intent(in) :: output_data(:,:)
     integer, intent(in) :: batch_size
     integer, intent(in) :: epochs
     class(optimizer_base_type), intent(in), optional :: optimizer
+    class(loss_type), intent(in), optional :: loss
     class(optimizer_base_type), allocatable :: optimizer_
 
     real :: pos
@@ -567,6 +584,14 @@ module subroutine train(self, input_data, output_data, batch_size, &
 
     call self % optimizer % init(self % get_num_params())
 
+    ! Passing the loss instance is optional.
+    ! If not provided, we default to quadratic().
+    if (present(loss)) then
+      self % loss = loss
+    else
+      self % loss = quadratic()
+    end if
+
     dataset_size = size(output_data, dim=2)
 
     epoch_loop: do n = 1, epochs
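On the training side, the new optional `loss` argument mirrors the existing `optimizer` argument. A call-site sketch, with hypothetical data shapes; omitting `loss` is equivalent to passing `quadratic()`:

```fortran
! Hypothetical shapes: 3 features, 2 targets, 1000 samples
real :: input_data(3, 1000), output_data(2, 1000)

call net % train( &
  input_data, output_data, &
  batch_size=32, &
  epochs=10, &
  optimizer=sgd(learning_rate=0.01), &
  loss=quadratic() &
)
```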