
Commit 5c9f27c (initial commit: 0 parents)
Author: tfukumot
Message: Upload files

File tree: 9 files changed, +1226 -0 lines changed

Lines changed: 134 additions & 0 deletions
%% Import data
rng(0);
data = table2array(readtable("trajectory_training.csv"));
ds = arrayDatastore(dlarray(data',"BC"));

%% Define Network
hiddenSize = 200;
inputSize = 2;
outputSize = 1;
net = [
    featureInputLayer(inputSize)
    fullyConnectedLayer(hiddenSize)
    tanhLayer()
    fullyConnectedLayer(hiddenSize)
    tanhLayer()
    fullyConnectedLayer(outputSize)];
% Create a dlnetwork object from the layer array.
net = dlnetwork(net);
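% Optional sanity check (an illustrative addition; xCheck and hCheck are
% hypothetical names): push a dummy batch through the untrained network to
% confirm it returns one scalar Hamiltonian value per observation. ones()
% is used instead of rand() so the rng(0) state above is left untouched.
xCheck = dlarray(ones(inputSize,4),"CB");
hCheck = predict(net,xCheck);   % 1-by-4 dlarray: one H(q,p) per column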
%% Specify Training Options
numEpochs = 300;
miniBatchSize = 750;
executionEnvironment = "auto";
initialLearnRate = 0.001;
decayRate = 1e-4;

%% Create a minibatchqueue
mbq = minibatchqueue(ds, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFormat','BC', ...
    'OutputEnvironment',executionEnvironment);
averageGrad = [];
averageSqGrad = [];

% Accelerate the model gradients function.
accfun = dlaccelerate(@modelGradients);

% Initialize the training progress plot.
figure
C = colororder;
lineLoss = animatedline('Color',C(2,:));
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on
set(gca, 'YScale', 'log');
hold off
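% Worked example of the decay schedule used in the training loop below:
% with learningRate = initialLearnRate/(1 + decayRate*iteration), the rate
% halves after 1/decayRate = 10,000 iterations, i.e. 0.001 -> 0.0005.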
%% Train model
start = tic;

iteration = 0;
for epoch = 1:numEpochs
    shuffle(mbq);
    while hasdata(mbq)
        iteration = iteration + 1;

        dlXT = next(mbq);
        dlX = dlXT(1:2,:);
        dlT = dlXT(3:4,:);

        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradients,loss] = dlfeval(accfun,net,dlX,dlT);

        % Update the learning rate.
        learningRate = initialLearnRate / (1+decayRate*iteration);

        % Update the network parameters using the adamupdate function.
        [net,averageGrad,averageSqGrad] = adamupdate(net,gradients,averageGrad, ...
            averageSqGrad,iteration,learningRate);
    end

    % Plot training progress.
    loss = double(gather(extractdata(loss)));
    addpoints(lineLoss,iteration,loss);
    drawnow
end
%% Test model
% To make predictions with the Hamiltonian NN we need to solve the ODE
% system given by Hamilton's equations:
%   dp/dt = -dH/dq,  dq/dt = dH/dp
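% For a concrete example (an assumption about this demo's system, not
% stated in the file): an ideal unit mass-spring has H(q,p) = (q^2+p^2)/2,
% so Hamilton's equations reduce to dq/dt = p and dp/dt = -q, whose
% solutions are circles in the (q,p) phase plane, which is the shape the
% phase plot below should trace.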
accOde = dlaccelerate(@predmodel);
t0 = dlarray(0,"CB");
x = dlarray([1,0],"BC");
dlfeval(accOde,t0,x,net);

% Since ode45 cannot operate on dlarray inputs, we write an ODE function
% that wraps accOde by converting the inputs to dlarray and extracting
% the data again after accOde is applied.
f = @(t,x) extractdata(accOde(dlarray(t,"CB"),dlarray(x,"CB"),net));

% Now solve with ode45.
x = single([1,0]);
t_span = linspace(0,20,2000);
noise_std = 0.1;
% Make predictions.
t_span = t_span.*(1 + .9*noise_std);
[~,dlqp] = ode45(f,t_span,x);
qp = squeeze(double(dlqp));
qp = qp.';
figure,plot(qp(1,:),qp(2,:))
hold on
load qp_baseline.mat   % expected to provide a baseline qp trajectory
plot(qp(1,:),qp(2,:))
hold off
legend(["Hamiltonian NN","Baseline"])
xlim([-1.1 1.1])
ylim([-1.1 1.1])
%% Supporting Functions
% modelGradients Function
function [gradients,loss] = modelGradients(net,dlX,dlT)

% Make predictions with the initial conditions.
dlU = forward(net,dlX);
[dq,dp] = dlderivative(dlU,dlX);
loss_dq = l2loss(dq,dlT(1,:));
loss_dp = l2loss(dp,dlT(2,:));
loss = loss_dq + loss_dp;
gradients = dlgradient(loss,net.Learnables);
end
% predmodel Function
function dlT_pred = predmodel(t,dlX,net)
dlU = forward(net,dlX);
[dq,dp] = dlderivative(dlU,dlX);
dlT_pred = [dq;dp];
end
% dlderivative Function
function [dq,dp] = dlderivative(F1,dlX)
dF1 = dlgradient(sum(F1,"all"),dlX);
dq = dF1(2,:);
dp = -dF1(1,:);
end
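% Reading the indexing above (assuming row 1 of dlX holds q and row 2
% holds p, consistent with the ODE comment in the test section):
% dF1(1,:) = dH/dq and dF1(2,:) = dH/dp, so Hamilton's equations
% dq/dt = dH/dp and dp/dt = -dH/dq match the two assignments exactly.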
%%
% _Copyright 2023 The MathWorks, Inc._

Demo_baseline_Spring.m

Lines changed: 123 additions & 0 deletions
%% Import data
rng(0);
data = table2array(readtable("trajectory_training.csv"));
ds = arrayDatastore(dlarray(data',"BC"));

%% Define Network
hiddenSize = 200;
inputSize = 2;
outputSize = 2;
net = [
    featureInputLayer(inputSize)
    fullyConnectedLayer(hiddenSize)
    tanhLayer()
    fullyConnectedLayer(hiddenSize)
    tanhLayer()
    fullyConnectedLayer(outputSize)];
% Create a dlnetwork object from the layer array.
net = dlnetwork(net);
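% Note the contrast with the Hamiltonian demo above: this baseline maps
% (q,p) directly to the two time derivatives (dq/dt, dp/dt), so outputSize
% is 2 and no derivative of a scalar H is taken inside the loss.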
%% Specify Training Options
numEpochs = 300;
miniBatchSize = 750;
executionEnvironment = "auto";
initialLearnRate = 0.001;
decayRate = 1e-4;

%% Create a minibatchqueue
mbq = minibatchqueue(ds, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFormat','BC', ...
    'OutputEnvironment',executionEnvironment);
averageGrad = [];
averageSqGrad = [];

% Accelerate the model gradients function.
accfun = dlaccelerate(@modelGradients);

% Initialize the training progress plot.
figure
C = colororder;
lineLoss = animatedline('Color',C(2,:));
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on
set(gca, 'YScale', 'log');
%% Train model
start = tic;

iteration = 0;
shuffle(mbq);

for epoch = 1:numEpochs
    % Reset returns the queue to the start of the data shuffled above;
    % per-epoch reshuffling is deliberately left disabled.
    reset(mbq);
    % shuffle(mbq);

    while hasdata(mbq)
        iteration = iteration + 1;

        dlXT = next(mbq);
        dlX = dlXT(1:2,:);
        dlT = dlXT(3:4,:);

        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradients,loss] = dlfeval(accfun,net,dlX,dlT);

        % Update the learning rate.
        learningRate = initialLearnRate / (1+decayRate*iteration);

        % Update the network parameters using the adamupdate function.
        [net,averageGrad,averageSqGrad] = adamupdate(net,gradients,averageGrad, ...
            averageSqGrad,iteration,learningRate);
    end

    % Plot training progress.
    loss = double(gather(extractdata(loss)));
    addpoints(lineLoss,iteration,loss);

    D = duration(0,0,toc(start),'Format','hh:mm:ss');
    title("Epoch: " + epoch + ", Elapsed: " + string(D) + ", Loss: " + loss)
    drawnow
end
%% Test model
accOde = dlaccelerate(@predmodel);
t0 = dlarray(0,"CB");
x = dlarray([1,0],"BC");
dlfeval(accOde,t0,x,net);

% Since ode45 cannot operate on dlarray inputs, we write an ODE function
% that wraps accOde by converting the inputs to dlarray and extracting
% the data again after accOde is applied.
f = @(t,x) extractdata(accOde(dlarray(t,"CB"),dlarray(x,"CB"),net));

% Now solve with ode45.
x = single([1,0]);
t_span = linspace(0,20,2000);
noise_std = 0.1;
% Make predictions.
t_span = t_span.*(1 + .9*noise_std);
[~,dlqp] = ode45(f,t_span,x);
qp = squeeze(double(dlqp));
qp = qp.';
figure,plot(qp(1,:),qp(2,:))
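% Presumably (an inference, not stated anywhere in this commit) this
% trajectory is the one saved as qp_baseline.mat for the comparison plot
% in the Hamiltonian demo, e.g. via:
% save("qp_baseline.mat","qp")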
%% Supporting Functions
% modelGradients Function
function [gradients,loss] = modelGradients(net,dlX,dlT)
% Make predictions with the initial conditions.
dlT_pred = forward(net,dlX);

loss = mse(dlT_pred,dlT);
% Calculate gradients with respect to the learnable parameters.
gradients = dlgradient(loss,net.Learnables);
end
% predmodel Function
function dlT_pred = predmodel(t,dlX,net)
dlT_pred = forward(net,dlX);
end

%%
% _Copyright 2023 The MathWorks, Inc._

Pics/1.png (67.3 KB)

Pics/2.png (6.5 KB)
