@@ -12,6 +12,8 @@ struct HybridESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork
1212 states:: IS
1313end
1414
15+ const AbstractDriver = Union{AbstractReservoirDriver, GRU}
16+
1517struct KnowledgeModel{T, K, O, I, S, D}
1618 prior_model:: T
1719 u0:: K
@@ -91,19 +93,12 @@ traditional Echo State Networks with a predefined knowledge model [^Pathak2018].
9193 "Hybrid Forecasting of Chaotic Processes:
9294 Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018).
9395"""
94- function HybridESN (model,
95- train_data,
96- in_size:: Int ,
97- res_size:: Int ;
98- input_layer= scaled_rand,
99- reservoir= rand_sparse,
100- bias= zeros32,
101- reservoir_driver= RNN (),
102- nla_type= NLADefault (),
103- states_type= StandardStates (),
104- washout= 0 ,
105- rng= Utils. default_rng (),
106- T= Float32,
96+ function HybridESN (model:: KnowledgeModel , train_data:: AbstractArray ,
97+ in_size:: Int , res_size:: Int ; input_layer= scaled_rand, reservoir= rand_sparse,
98+ bias= zeros32, reservoir_driver:: AbstractDriver = RNN (),
99+ nla_type:: NonLinearAlgorithm = NLADefault (),
100+ states_type:: AbstractStates = StandardStates (), washout:: Int = 0 ,
101+ rng:: AbstractRNG = Utils. default_rng (), T= Float32,
107102 matrix_type= typeof (train_data))
108103 train_data = vcat (train_data, model. model_data[:, 1 : (end - 1 )])
109104
@@ -130,8 +125,7 @@ function HybridESN(model,
130125end
131126
132127function (hesn:: HybridESN )(prediction:: AbstractPrediction ,
133- output_layer:: AbstractOutputLayer ;
134- last_state= hesn. states[:, [end ]],
128+ output_layer:: AbstractOutputLayer ; last_state:: AbstractArray = hesn. states[:, [end ]],
135129 kwargs... )
136130 km = hesn. model
137131 pred_len = prediction. prediction_len
@@ -148,10 +142,8 @@ function (hesn::HybridESN)(prediction::AbstractPrediction,
148142 kwargs... )
149143end
150144
151- function train (hesn:: HybridESN ,
152- target_data,
153- training_method= StandardRidge ();
154- kwargs... )
145+ function train (hesn:: HybridESN , target_data:: AbstractArray ,
146+ training_method= StandardRidge (); kwargs... )
155147 states = vcat (hesn. states, hesn. model. model_data[:, 2 : end ])
156148 states_new = hesn. states_type (hesn. nla_type, states, hesn. train_data[:, 1 : end ])
157149
0 commit comments