Skip to content

Commit b6fd59e

Browse files
committed
Eig support
1 parent a76e240 commit b6fd59e

File tree

3 files changed

+121
-102
lines changed

3 files changed

+121
-102
lines changed

test/eig.jl

Lines changed: 22 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -8,109 +8,29 @@ using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
88
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
99
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
1010

11-
@testset "eig_full! for T = $T" for T in BLASFloats
12-
rng = StableRNG(123)
13-
m = 54
14-
for alg in (LAPACK_Simple(), LAPACK_Expert(), :LAPACK_Simple, LAPACK_Simple)
15-
A = randn(rng, T, m, m)
16-
Tc = complex(T)
11+
using CUDA, AMDGPU
1712

18-
D, V = @constinferred eig_full(A; alg = ($alg))
19-
@test eltype(D) == eltype(V) == Tc
20-
@test A * V V * D
21-
22-
alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg)
23-
24-
Ac = similar(A)
25-
D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′)
26-
@test D2 === D
27-
@test V2 === V
28-
@test A * V V * D
29-
30-
Dc = @constinferred eig_vals(A, alg′)
31-
@test eltype(Dc) == Tc
32-
@test D Diagonal(Dc)
33-
end
34-
end
35-
36-
@testset "eig_trunc! for T = $T" for T in BLASFloats
37-
rng = StableRNG(123)
38-
m = 54
39-
for alg in (LAPACK_Simple(), LAPACK_Expert())
40-
A = randn(rng, T, m, m)
41-
A *= A' # TODO: deal with eigenvalue ordering etc
42-
# eigenvalues are sorted by ascending real component...
43-
D₀ = sort!(eig_vals(A); by = abs, rev = true)
44-
rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
45-
r = length(D₀) - rmin
46-
atol = sqrt(eps(real(T)))
47-
48-
D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r))
49-
@test length(diagview(D1)) == r
50-
@test A * V1 V1 * D1
51-
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
52-
53-
s = 1 + sqrt(eps(real(T)))
54-
trunc = trunctol(; atol = s * abs(D₀[r + 1]))
55-
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc)
56-
@test length(diagview(D2)) == r
57-
@test A * V2 V2 * D2
58-
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
59-
60-
s = 1 - sqrt(eps(real(T)))
61-
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
62-
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc)
63-
@test length(diagview(D3)) == r
64-
@test A * V3 V3 * D3
65-
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
66-
67-
# trunctol keeps order, truncrank might not
68-
# test for same subspace
69-
@test V1 * ((V1' * V1) \ (V1' * V2)) V2
70-
@test V2 * ((V2' * V2) \ (V2' * V1)) V1
71-
@test V1 * ((V1' * V1) \ (V1' * V3)) V3
72-
@test V3 * ((V3' * V3) \ (V3' * V1)) V1
13+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
14+
GenericFloats = (Float16,) #BigFloat, Complex{BigFloat})
15+
16+
@isdefined(TestSuite) || include("testsuite/TestSuite.jl")
17+
using .TestSuite
18+
19+
m = 54
20+
for T in BLASFloats
21+
TestSuite.seed_rng!(123)
22+
TestSuite.test_eig(T, (m, m))
23+
if CUDA.functional()
24+
TestSuite.test_eig(CuMatrix{T}, (m, m); test_blocksize = false)
25+
TestSuite.test_eig(Diagonal{T, CuVector{T}}, m; test_blocksize = false)
7326
end
27+
#= not yet supported
28+
if AMDGPU.functional()
29+
TestSuite.test_eig(ROCMatrix{T}, (m, m); test_blocksize = false)
30+
TestSuite.test_eig(Diagonal{T, ROCVector{T}}, m; test_blocksize = false)
31+
end=#
7432
end
75-
76-
@testset "eig_trunc! specify truncation algorithm T = $T" for T in BLASFloats
77-
rng = StableRNG(123)
78-
m = 4
79-
atol = sqrt(eps(real(T)))
80-
V = randn(rng, T, m, m)
81-
D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01])
82-
A = V * D * inv(V)
83-
alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2))
84-
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg)
85-
@test diagview(D2) diagview(D)[1:2]
86-
@test ϵ2 norm(diagview(D)[3:4]) atol = atol
87-
@test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2))
88-
89-
alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1))
90-
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg)
91-
@test diagview(D3) diagview(D)[1:2]
92-
@test ϵ3 norm(diagview(D)[3:4]) atol = atol
93-
end
94-
95-
@testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
96-
rng = StableRNG(123)
97-
m = 54
98-
Ad = randn(rng, T, m)
99-
A = Diagonal(Ad)
100-
atol = sqrt(eps(real(T)))
101-
102-
D, V = @constinferred eig_full(A)
103-
@test D isa Diagonal{T} && size(D) == size(A)
104-
@test V isa Diagonal{T} && size(V) == size(A)
105-
@test A * V V * D
106-
107-
D2 = @constinferred eig_vals(A)
108-
@test D2 isa AbstractVector{T} && length(D2) == m
109-
@test diagview(D) D2
110-
111-
A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01])
112-
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
113-
D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg)
114-
@test diagview(D2) diagview(A2)[1:2]
115-
@test ϵ2 norm(diagview(A2)[3:4]) atol = atol
33+
for T in (BLASFloats..., GenericFloats...)
34+
AT = Diagonal{T, Vector{T}}
35+
TestSuite.test_eig(AT, m; test_blocksize = false)
11636
end

test/testsuite/TestSuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,6 @@ include("polar.jl")
6969
include("orthnull.jl")
7070
include("projections.jl")
7171
include("eigh.jl")
72+
include("eig.jl")
7273

7374
end

test/testsuite/eig.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
using TestExtras
2+
3+
function test_eig(T::Type, sz; kwargs...)
4+
summary_str = testargs_summary(T, sz)
5+
return @testset "eig $summary_str" begin
6+
test_eig_full(T, sz; kwargs...)
7+
test_eig_trunc(T, sz; kwargs...)
8+
end
9+
end
10+
11+
function test_eig_full(
12+
T::Type, sz;
13+
test_blocksize = true,
14+
atol::Real = 0, rtol::Real = precision(T),
15+
kwargs...
16+
)
17+
summary_str = testargs_summary(T, sz)
18+
return @testset "eig_full! $summary_str" begin
19+
A = instantiate_matrix(T, sz)
20+
Ac = deepcopy(A)
21+
Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T))
22+
D, V = @testinferred eig_full(A)
23+
@test eltype(D) == eltype(V) == Tc
24+
@test A * V V * D
25+
26+
D2, V2 = @testinferred eig_full!(Ac, (D, V))
27+
@test D2 === D
28+
@test V2 === V
29+
@test A * V V * D
30+
31+
Dc = @testinferred eig_vals(A)
32+
@test eltype(Dc) == Tc
33+
@test D Diagonal(Dc)
34+
end
35+
end
36+
37+
function test_eig_trunc(
38+
T::Type, sz;
39+
test_blocksize = true,
40+
atol::Real = 0, rtol::Real = precision(T),
41+
kwargs...
42+
)
43+
summary_str = testargs_summary(T, sz)
44+
return @testset "eig_trunc! $summary_str" begin
45+
A = instantiate_matrix(T, sz)
46+
A *= A' # TODO: deal with eigenvalue ordering etc
47+
Ac = deepcopy(A)
48+
Tc = complex(eltype(T))
49+
# eigenvalues are sorted by ascending real component...
50+
D₀ = sort!(eig_vals(A); by = abs, rev = true)
51+
m = size(A, 1)
52+
rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
53+
r = length(D₀) - rmin
54+
atol = sqrt(eps(real(eltype(T))))
55+
56+
D1, V1, ϵ1 = @testinferred eig_trunc(A; trunc = truncrank(r))
57+
@test length(diagview(D1)) == r
58+
@test A * V1 V1 * D1
59+
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
60+
61+
s = 1 + sqrt(eps(real(eltype(T))))
62+
trunc = trunctol(; atol = s * abs(D₀[r + 1]))
63+
D2, V2, ϵ2 = @testinferred eig_trunc(A; trunc)
64+
@test length(diagview(D2)) == r
65+
@test A * V2 V2 * D2
66+
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
67+
68+
s = 1 - sqrt(eps(real(eltype(T))))
69+
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
70+
D3, V3, ϵ3 = @testinferred eig_trunc(A; trunc)
71+
@test length(diagview(D3)) == r
72+
@test A * V3 V3 * D3
73+
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
74+
75+
# trunctol keeps order, truncrank might not
76+
# test for same subspace
77+
@test V1 * ((V1' * V1) \ (V1' * V2)) V2
78+
@test V2 * ((V2' * V2) \ (V2' * V1)) V1
79+
@test V1 * ((V1' * V1) \ (V1' * V3)) V3
80+
@test V3 * ((V3' * V3) \ (V3' * V1)) V1
81+
82+
# TODO
83+
#=atol = sqrt(eps(real(eltype(T))))
84+
V = randn(rng, T, m, m)
85+
D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01])
86+
A = V * D * inv(V)
87+
alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2))
88+
D2, V2, ϵ2 = @testinferred eig_trunc(A; alg)
89+
@test diagview(D2) ≈ diagview(D)[1:2]
90+
@test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol
91+
@test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2))
92+
93+
alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1))
94+
D3, V3, ϵ3 = @testinferred eig_trunc(A; alg)
95+
@test diagview(D3) ≈ diagview(D)[1:2]
96+
@test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol=#
97+
end
98+
end

0 commit comments

Comments
 (0)