Skip to content

Commit 6cae487

Browse files
Merge pull request #257 from SciML/fm/ps
Adding PartialSquare
2 parents 015293d + 2764498 commit 6cae487

File tree

5 files changed

+77
-3
lines changed

5 files changed

+77
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ReservoirComputing"
22
uuid = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294"
33
authors = ["Francesco Martinuzzi"]
4-
version = "0.10.11"
4+
version = "0.10.12"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

docs/src/api/states.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
NLAT1
1717
NLAT2
1818
NLAT3
19+
PartialSquare
1920
```
2021

2122
## Internals

src/ReservoirComputing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ include("esn/esn_predict.jl")
3333
include("reca/reca.jl")
3434
include("reca/reca_input_encodings.jl")
3535

36-
export NLADefault, NLAT1, NLAT2, NLAT3
36+
export NLADefault, NLAT1, NLAT2, NLAT3, PartialSquare
3737
export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
3838
export StandardRidge
3939
export scaled_rand, weighted_init, informed_init, minimal_init, chebyshev_mapping,

src/states.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,3 +654,75 @@ function (::NLAT3)(x_old::AbstractVector)
654654

655655
return x_new
656656
end
657+
658+
@doc raw"""
659+
PartialSquare(eta)
660+
661+
Implement a partial squaring of the states as described in [^barbosa2021].
662+
663+
# Equations
664+
665+
```math
666+
\begin{equation}
667+
g(r_i) =
668+
\begin{cases}
669+
r_i^2, & \text{if } i \leq \eta_r N, \\
670+
r_i, & \text{if } i > \eta_r N.
671+
\end{cases}
672+
\end{equation}
673+
```
674+
675+
# Examples
676+
677+
```jldoctest
678+
julia> ps = PartialSquare(0.6)
679+
PartialSquare(0.6)
680+
681+
682+
julia> x_old = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
683+
10-element Vector{Int64}:
684+
0
685+
1
686+
2
687+
3
688+
4
689+
5
690+
6
691+
7
692+
8
693+
9
694+
695+
julia> x_new = ps(x_old)
696+
10-element Vector{Int64}:
697+
0
698+
1
699+
4
700+
9
701+
16
702+
25
703+
6
704+
7
705+
8
706+
9
707+
708+
709+
[^barbosa2021]: Barbosa, Wendson AS, et al.
710+
"Symmetry-aware reservoir computing."
711+
Physical Review E 104.4 (2021): 045307.
712+
"""
713+
struct PartialSquare <: NonLinearAlgorithm
714+
eta::Number
715+
end
716+
717+
function (ps::PartialSquare)(x_old::AbstractVector)
718+
x_new = copy(x_old)
719+
n_length = length(x_old)
720+
threshold = floor(Int, ps.eta * n_length)
721+
for idx in eachindex(x_old)
722+
if idx <= threshold
723+
x_new[idx] = x_old[idx]^2
724+
end
725+
end
726+
727+
return x_new
728+
end

test/test_states.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ test_types = [Float64, Float32, Float16]
88
nlas = [(NLADefault(), test_array),
99
(NLAT1(), [1, 2, 9, 4, 25, 6, 49, 8, 81]),
1010
(NLAT2(), [1, 2, 2, 4, 12, 6, 30, 8, 9]),
11-
(NLAT3(), [1, 2, 8, 4, 24, 6, 48, 8, 9])]
11+
(NLAT3(), [1, 2, 8, 4, 24, 6, 48, 8, 9]),
12+
(PartialSquare(0.6), [1, 4, 9, 16, 25, 6, 7, 8, 9])]
1213

1314
pes = [(StandardStates(), test_array),
1415
(PaddedStates(; padding=padding),

0 commit comments

Comments
 (0)