Skip to content

Commit dce3589

Browse files
removed StatsBase dependency, added error message
1 parent d519c9c commit dce3589

File tree

4 files changed

+30
-9
lines changed

4 files changed

+30
-9
lines changed

Project.toml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
14-
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1514
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
1615

1716
[weakdeps]
@@ -39,7 +38,6 @@ Reexport = "1.2.2"
3938
SafeTestsets = "0.1"
4039
SparseArrays = "1.10"
4140
Statistics = "1.10"
42-
StatsBase = "0.34.4"
4341
Test = "1"
4442
WeightInitializers = "1.0.5"
4543
julia = "1.10"
@@ -49,11 +47,10 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4947
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
5048
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
5149
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
52-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5350
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
51+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5452
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
5553
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5654

5755
[targets]
58-
test = ["Aqua", "Test", "SafeTestsets", "DifferentialEquations",
59-
"MLJLinearModels", "LIBSVM", "Statistics", "SparseArrays"]
56+
test = ["Aqua", "Test", "SafeTestsets", "DifferentialEquations", "MLJLinearModels", "LIBSVM", "Statistics", "SparseArrays"]

src/ReservoirComputing.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@ using CellularAutomata: CellularAutomaton
55
using Compat: @compat
66
using LinearAlgebra: eigvals, mul!, I, qr, Diagonal
77
using NNlib: fast_act, sigmoid
8-
using Random: Random, AbstractRNG
8+
using Random: Random, AbstractRNG, randperm
99
using Reexport: Reexport, @reexport
10-
using StatsBase: sample
1110
using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
1211
@reexport using WeightInitializers
1312

@@ -44,7 +43,7 @@ export rand_sparse, delay_line, delay_line_backward, cycle_jumps,
4443
export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
4544
export train
4645
export ESN, HybridESN, KnowledgeModel, DeepESN
47-
export RECA, sample
46+
export RECA
4847
export RandomMapping, RandomMaps
4948
export Generative, Predictive, OutputLayer
5049

src/esn/esn_inits.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,19 @@ function return_init_as(::Val{false}, layer_matrix::AbstractVecOrMat)
33
return layer_matrix
44
end
55

6+
# error for sparse inits when SparseArrays.jl has not been loaded
7+
8+
function throw_sparse_error(return_sparse)
9+
if return_sparse && !haskey(Base.loaded_modules, :SparseArrays)
10+
error("""\n
11+
Sparse output requested but SparseArrays.jl is not loaded.
12+
Please load it with:
13+
14+
using SparseArrays\n
15+
""")
16+
end
17+
end
18+
619
### input layers
720
"""
821
scaled_rand([rng], [T], dims...;
@@ -91,6 +104,7 @@ julia> res_input = weighted_init(8, 3)
91104
"""
92105
function weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
93106
scaling=T(0.1), return_sparse::Bool=false) where {T <: Number}
107+
throw_sparse_error(return_sparse)
94108
approx_res_size, in_size = dims
95109
res_size = Int(floor(approx_res_size / in_size) * in_size)
96110
layer_matrix = DeviceAgnostic.zeros(rng, T, res_size, in_size)
@@ -354,6 +368,7 @@ function chebyshev_mapping(rng::AbstractRNG, ::Type{T}, dims::Integer...;
354368
amplitude::AbstractFloat=one(T), sine_divisor::AbstractFloat=one(T),
355369
chebyshev_parameter::AbstractFloat=one(T),
356370
return_sparse::Bool=false) where {T <: Number}
371+
throw_sparse_error(return_sparse)
357372
input_matrix = DeviceAgnostic.zeros(rng, T, dims...)
358373
n_rows, n_cols = dims[1], dims[2]
359374

@@ -433,6 +448,7 @@ function logistic_mapping(rng::AbstractRNG, ::Type{T}, dims::Integer...;
433448
amplitude::AbstractFloat=0.3, sine_divisor::AbstractFloat=5.9,
434449
logistic_parameter::AbstractFloat=3.7,
435450
return_sparse::Bool=false) where {T <: Number}
451+
throw_sparse_error(return_sparse)
436452
input_matrix = DeviceAgnostic.zeros(rng, T, dims...)
437453
num_rows, num_columns = dims[1], dims[2]
438454
for col in 1:num_columns
@@ -533,6 +549,7 @@ function modified_lm(rng::AbstractRNG, ::Type{T}, dims::Integer...;
533549
factor::Integer, amplitude::AbstractFloat=0.3,
534550
sine_divisor::AbstractFloat=5.9, logistic_parameter::AbstractFloat=2.35,
535551
return_sparse::Bool=false) where {T <: Number}
552+
throw_sparse_error(return_sparse)
536553
num_columns = dims[2]
537554
expected_num_rows = factor * num_columns
538555
if dims[1] != expected_num_rows
@@ -599,6 +616,7 @@ julia> res_matrix = rand_sparse(5, 5; sparsity=0.5)
599616
function rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...;
600617
radius=T(1.0), sparsity=T(0.1), std=T(1.0),
601618
return_sparse::Bool=false) where {T <: Number}
619+
throw_sparse_error(return_sparse)
602620
lcl_sparsity = T(1) - sparsity #consistency with current implementations
603621
reservoir_matrix = sparse_init(rng, T, dims...; sparsity=lcl_sparsity, std=std)
604622
rho_w = maximum(abs.(eigvals(reservoir_matrix)))
@@ -660,6 +678,7 @@ julia> res_matrix = delay_line(5, 5; weight=1)
660678
"""
661679
function delay_line(rng::AbstractRNG, ::Type{T}, dims::Integer...;
662680
weight=T(0.1), return_sparse::Bool=false) where {T <: Number}
681+
throw_sparse_error(return_sparse)
663682
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
664683
@assert length(dims) == 2&&dims[1] == dims[2] """\n
665684
The dimensions must define a square matrix
@@ -723,6 +742,7 @@ julia> res_matrix = delay_line_backward(Float16, 5, 5)
723742
"""
724743
function delay_line_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...;
725744
weight=T(0.1), fb_weight=T(0.2), return_sparse::Bool=false) where {T <: Number}
745+
throw_sparse_error(return_sparse)
726746
res_size = first(dims)
727747
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
728748

@@ -787,6 +807,7 @@ julia> res_matrix = cycle_jumps(5, 5; jump_size=2)
787807
function cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...;
788808
cycle_weight::Number=T(0.1), jump_weight::Number=T(0.1),
789809
jump_size::Int=3, return_sparse::Bool=false) where {T <: Number}
810+
throw_sparse_error(return_sparse)
790811
res_size = first(dims)
791812
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
792813

@@ -854,6 +875,7 @@ julia> res_matrix = simple_cycle(5, 5; weight=11)
854875
"""
855876
function simple_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
856877
weight=T(0.1), return_sparse::Bool=false) where {T <: Number}
878+
throw_sparse_error(return_sparse)
857879
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
858880

859881
for i in 1:(dims[1] - 1)
@@ -916,6 +938,7 @@ function pseudo_svd(rng::AbstractRNG, ::Type{T}, dims::Integer...;
916938
max_value::Number=T(1.0), sparsity::Number=0.1, sorted::Bool=true,
917939
reverse_sort::Bool=false, return_sparse::Bool=false,
918940
return_diag::Bool=false) where {T <: Number}
941+
throw_sparse_error(return_sparse)
919942
reservoir_matrix = create_diag(rng, T, dims[1],
920943
max_value;
921944
sorted=sorted,
@@ -1039,6 +1062,7 @@ julia> res_matrix = chaotic_init(8, 8)
10391062
function chaotic_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
10401063
extra_edge_probability::AbstractFloat=T(0.1), spectral_radius::AbstractFloat=one(T),
10411064
return_sparse::Bool=false) where {T <: Number}
1065+
throw_sparse_error(return_sparse)
10421066
requested_order = first(dims)
10431067
if length(dims) > 1 && dims[2] != requested_order
10441068
@warn """\n

src/reca/reca_input_encodings.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,5 +102,6 @@ function init_maps(input_size, permutations, mapped_vector_size)
102102
end
103103

104104
function mapping(input_size, mapped_vector_size)
105-
return sample(1:mapped_vector_size, input_size; replace=false)
105+
#sample(1:mapped_vector_size, input_size; replace=false)
106+
return randperm(mapped_vector_size)[1:input_size]
106107
end

0 commit comments

Comments
 (0)