Skip to content

Commit c916c45

Browse files
Merge pull request #136 from SciML/fm/states_pred
Adding `save_states`
2 parents 4814d55 + acf5776 commit c916c45

File tree

5 files changed

+25
-9
lines changed

5 files changed

+25
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ReservoirComputing"
22
uuid = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294"
33
authors = ["Francesco Martinuzzi"]
4-
version = "0.9.0"
4+
version = "0.9.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

docs/src/esn_tutorials/lorenz_basic.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ Once the ```OutputLayer``` has been obtained the prediction can be done followin
8686
output = esn(Generative(predict_len), output_layer)
8787
```
8888
both the training method and the output layer are needed in this call. The number of steps for the prediction must be specified to the ```Generative``` method. The output results are given in a matrix.
89+
90+
!!! info "Saving the states during prediction"
91+
While the states are saved in the `ESN` struct for the training, for the prediction they are not saved by default. To inspect the states it is necessary to pass the boolean keyword argument `save_states` to the prediction call, in this example using `esn(... ; save_states=true)`. This returns a tuple `(output, states)` where `size(states) = res_size, prediction_len`
92+
8993
To inspect the results they can easily be plotted using an external library. In this case ```Plots``` is adopted:
9094
```julia
9195
using Plots, Plots.PlotMeasures

src/esn/echostatenetwork.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ end
164164

165165
function (esn::ESN)(prediction::AbstractPrediction,
166166
output_layer::AbstractOutputLayer;
167-
initial_conditions = output_layer.last_value,
168-
last_state = esn.states[:, [end]])
167+
last_state = esn.states[:, [end]],
168+
kwargs...)
169169
variation = esn.variation
170170
pred_len = prediction.prediction_len
171171

@@ -178,10 +178,10 @@ function (esn::ESN)(prediction::AbstractPrediction,
178178
model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end]
179179
return obtain_esn_prediction(esn, prediction, last_state, output_layer,
180180
model_pred_data;
181-
initial_conditions = initial_conditions)
181+
kwargs...)
182182
else
183183
return obtain_esn_prediction(esn, prediction, last_state, output_layer;
184-
initial_conditions = initial_conditions)
184+
kwargs...)
185185
end
186186
end
187187

src/esn/esn_predict.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ function obtain_esn_prediction(esn,
33
x,
44
output_layer,
55
args...;
6-
initial_conditions = output_layer.last_value)
6+
initial_conditions = output_layer.last_value,
7+
save_states = false)
78
out_size = output_layer.out_size
89
training_method = output_layer.training_method
910
prediction_len = prediction.prediction_len
1011

1112
output = output_storing(training_method, out_size, prediction_len, typeof(esn.states))
1213
out = initial_conditions
14+
states = similar(esn.states, size(esn.states, 1), prediction_len)
1315

1416
out_pad = allocate_outpad(esn.variation, esn.states_type, out)
1517
tmp_array = allocate_tmp(esn.reservoir_driver, typeof(esn.states), esn.res_size)
@@ -20,23 +22,26 @@ function obtain_esn_prediction(esn,
2022
args...)
2123
out_tmp = get_prediction(output_layer.training_method, output_layer, x_new)
2224
out = store_results!(output_layer.training_method, out_tmp, output, i)
25+
states[:, i] = x
2326
end
2427

25-
return output
28+
save_states ? (output, states) : output
2629
end
2730

2831
function obtain_esn_prediction(esn,
2932
prediction::Predictive,
3033
x,
3134
output_layer,
3235
args...;
33-
initial_conditions = output_layer.last_value)
36+
initial_conditions = output_layer.last_value,
37+
save_states = false)
3438
out_size = output_layer.out_size
3539
training_method = output_layer.training_method
3640
prediction_len = prediction.prediction_len
3741

3842
output = output_storing(training_method, out_size, prediction_len, typeof(esn.states))
3943
out = initial_conditions
44+
states = similar(esn.states, size(esn.states, 1), prediction_len)
4045

4146
out_pad = allocate_outpad(esn.variation, esn.states_type, out)
4247
tmp_array = allocate_tmp(esn.reservoir_driver, typeof(esn.states), esn.res_size)
@@ -47,9 +52,10 @@ function obtain_esn_prediction(esn,
4752
out_pad, i, tmp_array, args...)
4853
out_tmp = get_prediction(training_method, output_layer, x_new)
4954
out = store_results!(training_method, out_tmp, output, i)
55+
states[:, i] = x
5056
end
5157

52-
return output
58+
save_states ? (output, states) : output
5359
end
5460

5561
#prediction dispatch on esn

test/esn/test_train.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,9 @@ for t in training_methods
2626
output = esn(Predictive(input_data), output_layer)
2727
@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.21
2828
end
29+
30+
for t in training_methods
31+
output_layer = train(esn, target_data, t)
32+
output, states = esn(Predictive(input_data), output_layer, save_states = true)
33+
@test size(states) == (res_size, size(input_data, 2))
34+
end

0 commit comments

Comments
 (0)