Skip to content
This repository was archived by the owner on Dec 11, 2022. It is now read-only.

Commit e983b59

Browse files
committed
🧹 no need to call SSDML.transform
1 parent 5dc61c9 commit e983b59

File tree

3 files changed

+30
-32
lines changed

3 files changed

+30
-32
lines changed

src/SimpleSDMLayers.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ isdir(_layers_assets_path) || mkpath(_layers_assets_path)
9999
export clip
100100

101101
function __init__()
102-
@require GBIF="ee291a33-5a6c-5552-a3c8-0f29a1181037" begin
102+
@require GBIF = "ee291a33-5a6c-5552-a3c8-0f29a1181037" begin
103103
@info "Loading GBIF support for SimpleSDMLayers.jl"
104104
include("integrations/GBIF.jl")
105105
end
@@ -111,7 +111,6 @@ function __init__()
111111
@require MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" begin
112112
@info "Loading MultivariateStats support for SimpleSDMLayers.jl"
113113
include("integrations/MultivariateStats.jl")
114-
export transform
115114
end
116115

117116
end
Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,44 @@
11
# WARNING this file is only loaded if MultivariateStats.jl is also active
22
# This all happens thanks to the Requires.jl package
33

4+
_allowed_transforms = (MultivariateStats.PCA, MultivariateStats.PPCA, MultivariateStats.KernelPCA, MultivariateStats.Whitening)
5+
AllowedMultivariateTransforms = Union{_allowed_transforms...}
6+
47
"""
58
MultivariateStats.fit(a, layers::Vector{T}, kwargs...) where T <: SimpleSDMLayer
69
7-
Overloads the `fit` function from `MultivariateStats.jl`.
10+
Overloads the `fit` function from `MultivariateStats.jl`.
811
"""
9-
function MultivariateStats.fit(a, layers::Vector{T}, kwargs...) where {T<:SimpleSDMLayer}
12+
function MultivariateStats.fit(proj::Type{K}, layers::Vector{T}; kwargs...) where {T<:SimpleSDMLayer, K <: AllowedMultivariateTransforms}
1013
_layers_are_compatible(layers) || return ArgumentError("layers are not compatible")
1114
common_keys = reduce(, keys.(layers))
1215
input = hcat([vcat([layer[key] for layer in layers]...) for key in common_keys]...)
13-
proj = MultivariateStats.fit(a, input, kwargs...)
16+
proj = MultivariateStats.fit(proj, input; kwargs...)
1417
return proj
1518
end
1619

1720
"""
1821
transform(proj, layers::Vector{V},
1922
kwargs...) where {PT<:Union{MultivariateStats.PCA, MultivariateStats.PPCA},V<:SimpleSDMLayer}
2023
21-
Overload of the `transform` function from `MultivariateStats.jl`. Here `proj` is a
22-
and output object from `MultivariateStats.fit` (see above).
24+
Overload of the `transform` function from `MultivariateStats.jl`. Here `proj` is an output object from `MultivariateStats.fit` (see above).
2325
"""
24-
function transform(
25-
proj,
26-
layers::AbstractVecOrMat{U};
27-
kwargs...,
26+
function MultivariateStats.transform(
27+
proj, layers::AbstractVecOrMat{U}; kwargs...
2828
) where {U<:SimpleSDMLayer}
29-
outdim = MultivariateStats.outdim(proj)
30-
newlayers = [similar(layers[begin]) for i = 1:outdim]
29+
_layers_are_compatible(layers) || return ArgumentError("layers are not compatible")
30+
31+
newlayers = [similar(first(layers)) for i in 1:MultivariateStats.outdim(proj)]
3132
common_keys = reduce(, keys.(layers))
3233

3334
input = hcat([vcat([layer[key] for layer in layers]...) for key in common_keys]...)
3435

3536
for (ct, key) in enumerate(common_keys)
3637
pcaproj = MultivariateStats.transform(proj, input[:, ct])
37-
for i = 1:outdim
38+
for i in 1:length(newlayers)
3839
newlayers[i][key] = pcaproj[i]
3940
end
4041
end
42+
4143
return newlayers
4244
end
43-
44-
Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
module SSLTestMVStats
2-
using SimpleSDMLayers
3-
using MultivariateStats
4-
using NeutralLandscapes
5-
using Test
2+
using SimpleSDMLayers
3+
using MultivariateStats
4+
using NeutralLandscapes
5+
using Test
66

7-
TEST_DIMS = (150,150)
8-
TEST_AUTOCORRELATION = 0.9
9-
TEST_NUM_LAYERS = 10
7+
TEST_DIMS = (150, 150)
8+
TEST_AUTOCORRELATION = 0.9
9+
TEST_NUM_LAYERS = 10
1010

11-
layers = [SimpleSDMResponse(rand(MidpointDisplacement(TEST_AUTOCORRELATION), TEST_DIMS...)) for i in 1:TEST_NUM_LAYERS]
11+
layers = [
12+
SimpleSDMResponse(rand(MidpointDisplacement(TEST_AUTOCORRELATION), TEST_DIMS...)) for
13+
i in 1:TEST_NUM_LAYERS
14+
]
1215

16+
@test typeof(transform(fit(PCA, layers), layers)) <: Vector{T} where {T<:SimpleSDMLayer}
17+
#@test typeof(transform(fit(PPCA, layers),layers)) <: Vector{T} where T<:SimpleSDMLayer
1318

14-
@test typeof(SimpleSDMLayers.transform(fit(PCA, layers),layers)) <: Vector{T} where T<:SimpleSDMLayer
15-
# @test typeof(transform(fit(PPCA, layers),layers)) <: Vector{T} where T<:SimpleSDMLayer
16-
17-
18-
19-
# @assert typeof(transform(fit(KernelPCA, layers),layers)) <: Vector{T} where T<:SimpleSDMLayer
20-
# @assert typeof(transform(fit(Whitening, layers),layers)) <: Vector{T} where T<:SimpleSDMLayer
19+
#@assert typeof(transform(fit(KernelPCA, layers),layers)) <: Vector{T} where T<:SimpleSDMLayer
20+
#@assert typeof(transform(fit(Whitening, layers),layers)) <: Vector{T} where T<:SimpleSDMLayer
2121

2222
end
23-

0 commit comments

Comments
 (0)