Skip to content

Commit 7c0ae70

Browse files
bulk implementation of PartialSquare
1 parent 015293d commit 7c0ae70

File tree

4 files changed

+21
-2
lines changed

4 files changed

+21
-2
lines changed

docs/src/api/states.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
NLAT1
1717
NLAT2
1818
NLAT3
19+
PartialSquare
1920
```
2021

2122
## Internals

src/ReservoirComputing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ include("esn/esn_predict.jl")
3333
include("reca/reca.jl")
3434
include("reca/reca_input_encodings.jl")
3535

36-
export NLADefault, NLAT1, NLAT2, NLAT3
36+
export NLADefault, NLAT1, NLAT2, NLAT3, PartialSquare
3737
export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
3838
export StandardRidge
3939
export scaled_rand, weighted_init, informed_init, minimal_init, chebyshev_mapping,

src/states.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,3 +654,20 @@ function (::NLAT3)(x_old::AbstractVector)
654654

655655
return x_new
656656
end
657+
658+
struct PartialSquare <: NonLinearAlgorithm
659+
eta::Number
660+
end
661+
662+
function (ps::PartialSquare)(x_old::AbstractVector)
663+
x_new = copy(x_old)
664+
n_length = length(x_old)
665+
threshold = floor(Int, ps.eta * n_length)
666+
for idx in eachindex(x_old)
667+
if idx <= threshold
668+
x_new[idx] = x_old[idx]^2
669+
end
670+
end
671+
672+
return x_new
673+
end

test/test_states.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ test_types = [Float64, Float32, Float16]
88
nlas = [(NLADefault(), test_array),
99
(NLAT1(), [1, 2, 9, 4, 25, 6, 49, 8, 81]),
1010
(NLAT2(), [1, 2, 2, 4, 12, 6, 30, 8, 9]),
11-
(NLAT3(), [1, 2, 8, 4, 24, 6, 48, 8, 9])]
11+
(NLAT3(), [1, 2, 8, 4, 24, 6, 48, 8, 9]),
12+
(PartialSquare(0.6), [1, 4, 9, 16, 25, 6, 7, 8, 9])]
1213

1314
pes = [(StandardStates(), test_array),
1415
(PaddedStates(; padding=padding),

0 commit comments

Comments
 (0)