|
1 | 1 | # WARNING this file is only loaded if MultivariateStats.jl is also active |
2 | 2 | # This all happens thanks to the Requires.jl package |
3 | 3 |
|
| 4 | +_allowed_transforms = (MultivariateStats.PCA, MultivariateStats.PPCA, MultivariateStats.KernelPCA, MultivariateStats.Whitening) |
| 5 | +AllowedMultivariateTransforms = Union{_allowed_transforms...} |
| 6 | + |
4 | 7 | """ |
5 | 8 | MultivariateStats.fit(a, layers::Vector{T}, kwargs...) where T <: SimpleSDMLayer |
6 | 9 |
|
7 | | - Overloads the `fit` function from `MultivariateStats.jl`. |
| 10 | +Overloads the `fit` function from `MultivariateStats.jl`. |
8 | 11 | """ |
9 | | -function MultivariateStats.fit(a, layers::Vector{T}, kwargs...) where {T<:SimpleSDMLayer} |
| 12 | +function MultivariateStats.fit(proj::Type{K}, layers::Vector{T}; kwargs...) where {T<:SimpleSDMLayer, K <: AllowedMultivariateTransforms} |
10 | 13 | _layers_are_compatible(layers) || return ArgumentError("layers are not compatible") |
11 | 14 | common_keys = reduce(∩, keys.(layers)) |
12 | 15 | input = hcat([vcat([layer[key] for layer in layers]...) for key in common_keys]...) |
13 | | - proj = MultivariateStats.fit(a, input, kwargs...) |
| 16 | + proj = MultivariateStats.fit(proj, input; kwargs...) |
14 | 17 | return proj |
15 | 18 | end |
16 | 19 |
|
17 | 20 | """ |
18 | 21 | transform(proj, layers::Vector{V}, |
19 | 22 | kwargs...) where {PT<:Union{MultivariateStats.PCA, MultivariateStats.PPCA},V<:SimpleSDMLayer} |
20 | 23 |
|
21 | | - Overload of the `transform` function from `MultivariateStats.jl`. Here `proj` is a |
22 | | - and output object from `MultivariateStats.fit` (see above). |
| 24 | +Overload of the `transform` function from `MultivariateStats.jl`. Here `proj` is an output object from `MultivariateStats.fit` (see above). |
23 | 25 | """ |
24 | | -function transform( |
25 | | - proj, |
26 | | - layers::AbstractVecOrMat{U}; |
27 | | - kwargs..., |
| 26 | +function MultivariateStats.transform( |
| 27 | + proj, layers::AbstractVecOrMat{U}; kwargs... |
28 | 28 | ) where {U<:SimpleSDMLayer} |
29 | | - outdim = MultivariateStats.outdim(proj) |
30 | | - newlayers = [similar(layers[begin]) for i = 1:outdim] |
| 29 | + _layers_are_compatible(layers) || return ArgumentError("layers are not compatible") |
| 30 | + |
| 31 | + newlayers = [similar(first(layers)) for i in 1:MultivariateStats.outdim(proj)] |
31 | 32 | common_keys = reduce(∩, keys.(layers)) |
32 | 33 |
|
33 | 34 | input = hcat([vcat([layer[key] for layer in layers]...) for key in common_keys]...) |
34 | 35 |
|
35 | 36 | for (ct, key) in enumerate(common_keys) |
36 | 37 | pcaproj = MultivariateStats.transform(proj, input[:, ct]) |
37 | | - for i = 1:outdim |
| 38 | + for i in 1:length(newlayers) |
38 | 39 | newlayers[i][key] = pcaproj[i] |
39 | 40 | end |
40 | 41 | end |
| 42 | + |
41 | 43 | return newlayers |
42 | 44 | end |
43 | | - |
44 | | - |
0 commit comments