Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions example/merge_networks.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
program merge_networks
use nf, only: dense, input, network, sgd
use nf_dense_layer, only: dense_layer
implicit none

type(network) :: net1, net2, net3
real, allocatable :: x1(:), x2(:)
real, allocatable :: y1(:), y2(:)
real, allocatable :: y(:)
integer, parameter :: num_iterations = 500
integer :: n, nn
integer :: net1_output_size, net2_output_size

x1 = [0.1, 0.3, 0.5]
x2 = [0.2, 0.4]
y = [0.123456, 0.246802, 0.369258, 0.482604, 0.505050, 0.628406, 0.741852]

net1 = network([ &
input(3), &
dense(2), &
dense(3), &
dense(2) &
])

net2 = network([ &
input(2), &
dense(5), &
dense(3) &
])

net1_output_size = product(net1 % layers(size(net1 % layers)) % layer_shape)
net2_output_size = product(net2 % layers(size(net2 % layers)) % layer_shape)

! Network 3
net3 = network([ &
input(net1_output_size + net2_output_size), &
dense(7) &
])

do n = 1, num_iterations

! Forward propagate two network branches
call net1 % forward(x1)
call net2 % forward(x2)

! Get outputs of net1 and net2, concatenate, and pass to net3
! A helper function could be made to take any number of networks
! and return the concatenated output. Such function would turn the following
! block into a one-liner.
select type (net1_output_layer => net1 % layers(size(net1 % layers)) % p)
type is (dense_layer)
y1 = net1_output_layer % output
end select

select type (net2_output_layer => net2 % layers(size(net2 % layers)) % p)
type is (dense_layer)
y2 = net2_output_layer % output
end select

call net3 % forward([y1, y2])

! Compute the gradients on the 3rd network
call net3 % backward(y)

! net3 % update() will clear the gradients immediately after updating
! the weights, so we need to pass the gradients to net1 and net2 first

! For net1 and net2, we can't use the existing net % backward() because
! it currently assumes that the output layer gradients are computed based
! on the loss function and not the gradient from the next layer.
! For now, we need to manually pass the gradient from the first hidden layer
! of net3 to the output layers of net1 and net2.
select type (next_layer => net3 % layers(2) % p)
! Assume net3's first hidden layer is dense;
! would need to be generalized to others.
type is (dense_layer)

nn = size(net1 % layers)
call net1 % layers(nn) % backward( &
net1 % layers(nn - 1), next_layer % gradient(1:net1_output_size) &
)

nn = size(net2 % layers)
call net2 % layers(nn) % backward( &
net2 % layers(nn - 1), next_layer % gradient(net1_output_size+1:size(next_layer % gradient)) &
)

end select

! Compute the gradients on hidden layers of net1, if any
do nn = size(net1 % layers)-1, 2, -1
select type (next_layer => net1 % layers(nn + 1) % p)
type is (dense_layer)
call net1 % layers(nn) % backward( &
net1 % layers(nn - 1), next_layer % gradient &
)
end select
end do

! Compute the gradients on hidden layers of net2, if any
do nn = size(net2 % layers)-1, 2, -1
select type (next_layer => net2 % layers(nn + 1) % p)
type is (dense_layer)
call net2 % layers(nn) % backward( &
net2 % layers(nn - 1), next_layer % gradient &
)
end select
end do

! Gradients are now computed on all networks and we can update the weights
call net1 % update(optimizer=sgd(learning_rate=1.))
call net2 % update(optimizer=sgd(learning_rate=1.))
call net3 % update(optimizer=sgd(learning_rate=1.))

if (mod(n, 50) == 0) then
print *, "Iteration ", n, ", output RMSE = ", &
sqrt(sum((net3 % predict([net1 % predict(x1), net2 % predict(x2)]) - y)**2) / size(y))
end if

end do

end program merge_networks