@@ -3,124 +3,28 @@ using Test
33using TestExtras
44using StableRNGs
55using LinearAlgebra: LinearAlgebra, Diagonal, I
6- using MatrixAlgebraKit : TruncatedAlgorithm, diagview, norm
6+ using CUDA, AMDGPU
77
88BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
99GenericFloats = (Float16, BigFloat, Complex{BigFloat})
1010
11- @testset " eigh_full! for T = $T " for T in BLASFloats
12- rng = StableRNG (123 )
13- m = 54
14- for alg in (
15- LAPACK_MultipleRelativelyRobustRepresentations (),
16- LAPACK_DivideAndConquer (),
17- LAPACK_QRIteration (),
18- LAPACK_Bisection (),
19- )
20- A = randn (rng, T, m, m)
21- A = (A + A' ) / 2
11+ @isdefined (TestSuite) || include (" testsuite/TestSuite.jl" )
12+ using . TestSuite
2213
23- D, V = @constinferred eigh_full (A; alg)
24- @test A * V ≈ V * D
25- @test isunitary (V)
26- @test all (isreal, D)
27-
28- D2, V2 = eigh_full! (copy (A), (D, V), alg)
29- @test D2 === D
30- @test V2 === V
31-
32- D3 = @constinferred eigh_vals (A, alg)
33- @test D ≈ Diagonal (D3)
14+ m = 54
15+ for T in BLASFloats
16+ TestSuite. seed_rng! (123 )
17+ TestSuite. test_eigh (T, (m, m))
18+ if CUDA. functional ()
19+ TestSuite. test_eigh (CuMatrix{T}, (m, m); test_blocksize = false )
20+ TestSuite. test_eigh (Diagonal{T, CuVector{T}}, m; test_blocksize = false )
3421 end
35- end
36-
37- @testset " eigh_trunc! for T = $T " for T in BLASFloats
38- rng = StableRNG (123 )
39- m = 54
40- for alg in (
41- LAPACK_QRIteration (),
42- LAPACK_Bisection (),
43- LAPACK_DivideAndConquer (),
44- LAPACK_MultipleRelativelyRobustRepresentations (),
45- )
46- A = randn (rng, T, m, m)
47- A = A * A'
48- A = (A + A' ) / 2
49- Ac = similar (A)
50- D₀ = reverse (eigh_vals (A))
51- r = m - 2
52- s = 1 + sqrt (eps (real (T)))
53- atol = sqrt (eps (real (T)))
54-
55- D1, V1, ϵ1 = @constinferred eigh_trunc (A; alg, trunc = truncrank (r))
56- @test length (diagview (D1)) == r
57- @test isisometric (V1)
58- @test A * V1 ≈ V1 * D1
59- @test LinearAlgebra. opnorm (A - V1 * D1 * V1' ) ≈ D₀[r + 1 ]
60- @test ϵ1 ≈ norm (view (D₀, (r + 1 ): m)) atol = atol
61-
62- trunc = trunctol (; atol = s * D₀[r + 1 ])
63- D2, V2, ϵ2 = @constinferred eigh_trunc (A; alg, trunc)
64- @test length (diagview (D2)) == r
65- @test isisometric (V2)
66- @test A * V2 ≈ V2 * D2
67- @test ϵ2 ≈ norm (view (D₀, (r + 1 ): m)) atol = atol
68-
69- s = 1 - sqrt (eps (real (T)))
70- trunc = truncerror (; atol = s * norm (@view (D₀[r: end ]), 1 ), p = 1 )
71- D3, V3, ϵ3 = @constinferred eigh_trunc (A; alg, trunc)
72- @test length (diagview (D3)) == r
73- @test A * V3 ≈ V3 * D3
74- @test ϵ3 ≈ norm (view (D₀, (r + 1 ): m)) atol = atol
75-
76- # test for same subspace
77- @test V1 * (V1' * V2) ≈ V2
78- @test V2 * (V2' * V1) ≈ V1
79- @test V1 * (V1' * V3) ≈ V3
80- @test V3 * (V3' * V1) ≈ V1
22+ if AMDGPU. functional ()
23+ TestSuite. test_eigh (ROCMatrix{T}, (m, m); test_blocksize = false )
24+ TestSuite. test_eigh (Diagonal{T, ROCVector{T}}, m; test_blocksize = false )
8125 end
8226end
83-
84- @testset " eigh_trunc! specify truncation algorithm T = $T " for T in BLASFloats
85- rng = StableRNG (123 )
86- m = 4
87- atol = sqrt (eps (real (T)))
88- V = qr_compact (randn (rng, T, m, m))[1 ]
89- D = Diagonal (real (T)[0.9 , 0.3 , 0.1 , 0.01 ])
90- A = V * D * V'
91- A = (A + A' ) / 2
92- alg = TruncatedAlgorithm (LAPACK_QRIteration (), truncrank (2 ))
93- D2, V2, ϵ2 = @constinferred eigh_trunc (A; alg)
94- @test diagview (D2) ≈ diagview (D)[1 : 2 ]
95- @test_throws ArgumentError eigh_trunc (A; alg, trunc = (; maxrank = 2 ))
96- @test ϵ2 ≈ norm (diagview (D)[3 : 4 ]) atol = atol
97-
98- alg = TruncatedAlgorithm (LAPACK_QRIteration (), truncerror (; atol = 0.2 ))
99- D3, V3, ϵ3 = @constinferred eigh_trunc (A; alg)
100- @test diagview (D3) ≈ diagview (D)[1 : 2 ]
101- @test ϵ3 ≈ norm (diagview (D)[3 : 4 ]) atol = atol
102- end
103-
104- @testset " eigh for Diagonal{$T }" for T in (BLASFloats... , GenericFloats... )
105- rng = StableRNG (123 )
106- m = 54
107- Ad = randn (rng, T, m)
108- Ad .+ = conj .(Ad)
109- A = Diagonal (Ad)
110- atol = sqrt (eps (real (T)))
111-
112- D, V = @constinferred eigh_full (A)
113- @test D isa Diagonal{real (T)} && size (D) == size (A)
114- @test V isa Diagonal{T} && size (V) == size (A)
115- @test A * V ≈ V * D
116-
117- D2 = @constinferred eigh_vals (A)
118- @test D2 isa AbstractVector{real (T)} && length (D2) == m
119- @test diagview (D) ≈ D2
120-
121- A2 = Diagonal (T[0.9 , 0.3 , 0.1 , 0.01 ])
122- alg = TruncatedAlgorithm (DiagonalAlgorithm (), truncrank (2 ))
123- D2, V2, ϵ2 = @constinferred eigh_trunc (A2; alg)
124- @test diagview (D2) ≈ diagview (A2)[1 : 2 ]
125- @test ϵ2 ≈ norm (diagview (A2)[3 : 4 ]) atol = atol
27+ for T in (BLASFloats... , GenericFloats... )
28+ AT = Diagonal{T, Vector{T}}
29+ TestSuite. test_eigh (AT, m; test_blocksize = false )
12630end
0 commit comments