Skip to content

Commit 1972a6b

Browse files
committed
More fixes, some stuff broken
1 parent 05a4994 commit 1972a6b

File tree

6 files changed

+195
-149
lines changed

6 files changed

+195
-149
lines changed

test/eigh.jl

Lines changed: 16 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -3,124 +3,28 @@ using Test
33
using TestExtras
44
using StableRNGs
55
using LinearAlgebra: LinearAlgebra, Diagonal, I
6-
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
6+
using CUDA, AMDGPU
77

88
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
99
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
1010

11-
@testset "eigh_full! for T = $T" for T in BLASFloats
12-
rng = StableRNG(123)
13-
m = 54
14-
for alg in (
15-
LAPACK_MultipleRelativelyRobustRepresentations(),
16-
LAPACK_DivideAndConquer(),
17-
LAPACK_QRIteration(),
18-
LAPACK_Bisection(),
19-
)
20-
A = randn(rng, T, m, m)
21-
A = (A + A') / 2
11+
@isdefined(TestSuite) || include("testsuite/TestSuite.jl")
12+
using .TestSuite
2213

23-
D, V = @constinferred eigh_full(A; alg)
24-
@test A * V V * D
25-
@test isunitary(V)
26-
@test all(isreal, D)
27-
28-
D2, V2 = eigh_full!(copy(A), (D, V), alg)
29-
@test D2 === D
30-
@test V2 === V
31-
32-
D3 = @constinferred eigh_vals(A, alg)
33-
@test D Diagonal(D3)
14+
m = 54
15+
for T in BLASFloats
16+
TestSuite.seed_rng!(123)
17+
TestSuite.test_eigh(T, (m, m))
18+
if CUDA.functional()
19+
TestSuite.test_eigh(CuMatrix{T}, (m, m); test_blocksize = false)
20+
TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m; test_blocksize = false)
3421
end
35-
end
36-
37-
@testset "eigh_trunc! for T = $T" for T in BLASFloats
38-
rng = StableRNG(123)
39-
m = 54
40-
for alg in (
41-
LAPACK_QRIteration(),
42-
LAPACK_Bisection(),
43-
LAPACK_DivideAndConquer(),
44-
LAPACK_MultipleRelativelyRobustRepresentations(),
45-
)
46-
A = randn(rng, T, m, m)
47-
A = A * A'
48-
A = (A + A') / 2
49-
Ac = similar(A)
50-
D₀ = reverse(eigh_vals(A))
51-
r = m - 2
52-
s = 1 + sqrt(eps(real(T)))
53-
atol = sqrt(eps(real(T)))
54-
55-
D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r))
56-
@test length(diagview(D1)) == r
57-
@test isisometric(V1)
58-
@test A * V1 V1 * D1
59-
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') D₀[r + 1]
60-
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
61-
62-
trunc = trunctol(; atol = s * D₀[r + 1])
63-
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc)
64-
@test length(diagview(D2)) == r
65-
@test isisometric(V2)
66-
@test A * V2 V2 * D2
67-
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
68-
69-
s = 1 - sqrt(eps(real(T)))
70-
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
71-
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc)
72-
@test length(diagview(D3)) == r
73-
@test A * V3 V3 * D3
74-
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
75-
76-
# test for same subspace
77-
@test V1 * (V1' * V2) V2
78-
@test V2 * (V2' * V1) V1
79-
@test V1 * (V1' * V3) V3
80-
@test V3 * (V3' * V1) V1
22+
if AMDGPU.functional()
23+
TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_blocksize = false)
24+
TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m; test_blocksize = false)
8125
end
8226
end
83-
84-
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats
85-
rng = StableRNG(123)
86-
m = 4
87-
atol = sqrt(eps(real(T)))
88-
V = qr_compact(randn(rng, T, m, m))[1]
89-
D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01])
90-
A = V * D * V'
91-
A = (A + A') / 2
92-
alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2))
93-
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
94-
@test diagview(D2) diagview(D)[1:2]
95-
@test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2))
96-
@test ϵ2 norm(diagview(D)[3:4]) atol = atol
97-
98-
alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2))
99-
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg)
100-
@test diagview(D3) diagview(D)[1:2]
101-
@test ϵ3 norm(diagview(D)[3:4]) atol = atol
102-
end
103-
104-
@testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
105-
rng = StableRNG(123)
106-
m = 54
107-
Ad = randn(rng, T, m)
108-
Ad .+= conj.(Ad)
109-
A = Diagonal(Ad)
110-
atol = sqrt(eps(real(T)))
111-
112-
D, V = @constinferred eigh_full(A)
113-
@test D isa Diagonal{real(T)} && size(D) == size(A)
114-
@test V isa Diagonal{T} && size(V) == size(A)
115-
@test A * V V * D
116-
117-
D2 = @constinferred eigh_vals(A)
118-
@test D2 isa AbstractVector{real(T)} && length(D2) == m
119-
@test diagview(D) D2
120-
121-
A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01])
122-
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
123-
D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg)
124-
@test diagview(D2) diagview(A2)[1:2]
125-
@test ϵ2 norm(diagview(A2)[3:4]) atol = atol
27+
for T in (BLASFloats..., GenericFloats...)
28+
AT = Diagonal{T, Vector{T}}
29+
TestSuite.test_eigh(AT, m; test_blocksize = false)
12630
end

test/linearmap.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module LinearMaps
33
export LinearMap
44

55
using MatrixAlgebraKit
6-
using MatrixAlgebraKit: AbstractAlgorithm
6+
using MatrixAlgebraKit: AbstractAlgorithm, DiagonalAlgorithm
77
import MatrixAlgebraKit as MAK
88

99
using LinearAlgebra: LinearAlgebra, lmul!, rmul!
@@ -32,6 +32,12 @@ module LinearMaps
3232
LinearMap.(MAK.initialize_output($f!, parent(A), alg))
3333
@eval MAK.$f!(A::LinearMap, F, alg::AbstractAlgorithm) =
3434
LinearMap.(MAK.$f!(parent(A), parent.(F), alg))
35+
@eval MAK.check_input(::typeof($f!), A::LinearMap, F, alg::DiagonalAlgorithm) =
36+
MAK.check_input($f!, parent(A), parent.(F), alg)
37+
@eval MAK.initialize_output(::typeof($f!), A::LinearMap, alg::DiagonalAlgorithm) =
38+
LinearMap.(MAK.initialize_output($f!, parent(A), alg))
39+
@eval MAK.$f!(A::LinearMap, F, alg::DiagonalAlgorithm) =
40+
LinearMap.(MAK.$f!(parent(A), parent.(F), alg))
3541
end
3642

3743
for f in (:qr, :lq, :svd)

test/testsuite/TestSuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,6 @@ include("lq.jl")
6868
include("polar.jl")
6969
include("orthnull.jl")
7070
include("projections.jl")
71+
include("eigh.jl")
7172

7273
end

test/testsuite/eigh.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
using TestExtras
2+
3+
function test_eigh(T::Type, sz; kwargs...)
4+
summary_str = testargs_summary(T, sz)
5+
return @testset "eigh $summary_str" begin
6+
test_eigh_full(T, sz; kwargs...)
7+
if eltype(T) <: Union{Float16, ComplexF16, Float32, Float64, ComplexF32, ComplexF64}
8+
test_eigh_trunc(T, sz; kwargs...)
9+
end
10+
end
11+
end
12+
13+
function test_eigh_full(
14+
T::Type, sz;
15+
test_blocksize = true,
16+
atol::Real = 0, rtol::Real = precision(T),
17+
kwargs...
18+
)
19+
summary_str = testargs_summary(T, sz)
20+
return @testset "eigh_full! $summary_str" begin
21+
A = instantiate_matrix(T, sz)
22+
A = (A + A') / 2
23+
Ac = deepcopy(A)
24+
25+
D, V = @testinferred eigh_full(A)
26+
@test A * V V * D
27+
@test isunitary(V)
28+
@test all(isreal, D)
29+
30+
D2, V2 = eigh_full!(copy(A), (D, V))
31+
@test D2 === D
32+
@test V2 === V
33+
34+
D3 = @testinferred eigh_vals(A)
35+
@test D Diagonal(D3)
36+
end
37+
end
38+
39+
function test_eigh_trunc(
40+
T::Type, sz;
41+
test_blocksize = true,
42+
atol::Real = 0, rtol::Real = precision(T),
43+
kwargs...
44+
)
45+
summary_str = testargs_summary(T, sz)
46+
return @testset "eigh_trunc! $summary_str" begin
47+
A = instantiate_matrix(T, sz)
48+
A = A * A'
49+
A = (A + A') / 2
50+
Ac = deepcopy(A)
51+
52+
m = size(A, 1)
53+
D₀ = reverse(eigh_vals(A))
54+
r = m - 2
55+
s = 1 + sqrt(eps(real(eltype(T))))
56+
atol = sqrt(eps(real(eltype(T))))
57+
local V1, V2, V3
58+
@testset "truncrank" begin
59+
D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r))
60+
@test length(diagview(D1)) == r
61+
@test isisometric(V1)
62+
@test A * V1 V1 * D1
63+
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') D₀[r + 1]
64+
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
65+
end
66+
@testset "trunctol" begin
67+
trunc = trunctol(; atol = s * D₀[r + 1])
68+
D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc)
69+
@test length(diagview(D2)) == r
70+
@test isisometric(V2)
71+
@test A * V2 V2 * D2
72+
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
73+
end
74+
@testset "truncerror" begin
75+
s = 1 - sqrt(eps(real(eltype(T))))
76+
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
77+
D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc)
78+
@test length(diagview(D3)) == r
79+
@test A * V3 V3 * D3
80+
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
81+
end
82+
83+
# test for same subspace
84+
@test V1 * (V1' * V2) V2
85+
@test V2 * (V2' * V1) V1
86+
@test V1 * (V1' * V3) V3
87+
@test V3 * (V3' * V1) V1
88+
89+
# TODO
90+
#=
91+
@testset "specify truncation algorithm" begin
92+
atol = sqrt(eps(real(eltype(T))))
93+
V = qr_compact(instantiate_matrix(T, sz))[1]
94+
D = Diagonal(real(eltype(T))[0.9, 0.3, 0.1, 0.01])
95+
A = V * D * V'
96+
A = (A + A') / 2
97+
alg = TruncatedAlgorithm(MatrixAlgebraKit.default_qr_algorithm(A), truncrank(2))
98+
D2, V2, ϵ2 = @testinferred eigh_trunc(A; alg)
99+
@test diagview(D2) ≈ diagview(D)[1:2]
100+
@test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2))
101+
@test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol
102+
103+
alg = TruncatedAlgorithm(MatrixAlgebraKit.default_qr_algorithm(A), truncerror(; atol = 0.2))
104+
D3, V3, ϵ3 = @testinferred eigh_trunc(A; alg)
105+
@test diagview(D3) ≈ diagview(D)[1:2]
106+
@test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol
107+
end=#
108+
end
109+
end

0 commit comments

Comments
 (0)