@@ -8,109 +8,29 @@ using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
88BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
99GenericFloats = (Float16, BigFloat, Complex{BigFloat})
1010
11- @testset " eig_full! for T = $T " for T in BLASFloats
12- rng = StableRNG (123 )
13- m = 54
14- for alg in (LAPACK_Simple (), LAPACK_Expert (), :LAPACK_Simple , LAPACK_Simple)
15- A = randn (rng, T, m, m)
16- Tc = complex (T)
11+ using CUDA, AMDGPU
1712
18- D, V = @constinferred eig_full (A; alg = ($ alg))
19- @test eltype (D) == eltype (V) == Tc
20- @test A * V ≈ V * D
21-
22- alg′ = @constinferred MatrixAlgebraKit. select_algorithm (eig_full!, A, $ alg)
23-
24- Ac = similar (A)
25- D2, V2 = @constinferred eig_full! (copy! (Ac, A), (D, V), alg′)
26- @test D2 === D
27- @test V2 === V
28- @test A * V ≈ V * D
29-
30- Dc = @constinferred eig_vals (A, alg′)
31- @test eltype (Dc) == Tc
32- @test D ≈ Diagonal (Dc)
33- end
34- end
35-
36- @testset " eig_trunc! for T = $T " for T in BLASFloats
37- rng = StableRNG (123 )
38- m = 54
39- for alg in (LAPACK_Simple (), LAPACK_Expert ())
40- A = randn (rng, T, m, m)
41- A *= A' # TODO : deal with eigenvalue ordering etc
42- # eigenvalues are sorted by ascending real component...
43- D₀ = sort! (eig_vals (A); by = abs, rev = true )
44- rmin = findfirst (i -> abs (D₀[end - i]) != abs (D₀[end - i - 1 ]), 1 : (m - 2 ))
45- r = length (D₀) - rmin
46- atol = sqrt (eps (real (T)))
47-
48- D1, V1, ϵ1 = @constinferred eig_trunc (A; alg, trunc = truncrank (r))
49- @test length (diagview (D1)) == r
50- @test A * V1 ≈ V1 * D1
51- @test ϵ1 ≈ norm (view (D₀, (r + 1 ): m)) atol = atol
52-
53- s = 1 + sqrt (eps (real (T)))
54- trunc = trunctol (; atol = s * abs (D₀[r + 1 ]))
55- D2, V2, ϵ2 = @constinferred eig_trunc (A; alg, trunc)
56- @test length (diagview (D2)) == r
57- @test A * V2 ≈ V2 * D2
58- @test ϵ2 ≈ norm (view (D₀, (r + 1 ): m)) atol = atol
59-
60- s = 1 - sqrt (eps (real (T)))
61- trunc = truncerror (; atol = s * norm (@view (D₀[r: end ]), 1 ), p = 1 )
62- D3, V3, ϵ3 = @constinferred eig_trunc (A; alg, trunc)
63- @test length (diagview (D3)) == r
64- @test A * V3 ≈ V3 * D3
65- @test ϵ3 ≈ norm (view (D₀, (r + 1 ): m)) atol = atol
66-
67- # trunctol keeps order, truncrank might not
68- # test for same subspace
69- @test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2
70- @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1
71- @test V1 * ((V1' * V1) \ (V1' * V3)) ≈ V3
72- @test V3 * ((V3' * V3) \ (V3' * V1)) ≈ V1
13+ BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
14+ GenericFloats = (Float16,) # BigFloat, Complex{BigFloat})
15+
16+ @isdefined (TestSuite) || include (" testsuite/TestSuite.jl" )
17+ using . TestSuite
18+
19+ m = 54
20+ for T in BLASFloats
21+ TestSuite. seed_rng! (123 )
22+ TestSuite. test_eig (T, (m, m))
23+ if CUDA. functional ()
24+ TestSuite. test_eig (CuMatrix{T}, (m, m); test_blocksize = false )
25+ TestSuite. test_eig (Diagonal{T, CuVector{T}}, m; test_blocksize = false )
7326 end
27+ #= not yet supported
28+ if AMDGPU.functional()
29+ TestSuite.test_eig(ROCMatrix{T}, (m, m); test_blocksize = false)
30+ TestSuite.test_eig(Diagonal{T, ROCVector{T}}, m; test_blocksize = false)
31+ end=#
7432end
75-
76- @testset " eig_trunc! specify truncation algorithm T = $T " for T in BLASFloats
77- rng = StableRNG (123 )
78- m = 4
79- atol = sqrt (eps (real (T)))
80- V = randn (rng, T, m, m)
81- D = Diagonal (real (T)[0.9 , 0.3 , 0.1 , 0.01 ])
82- A = V * D * inv (V)
83- alg = TruncatedAlgorithm (LAPACK_Simple (), truncrank (2 ))
84- D2, V2, ϵ2 = @constinferred eig_trunc (A; alg)
85- @test diagview (D2) ≈ diagview (D)[1 : 2 ]
86- @test ϵ2 ≈ norm (diagview (D)[3 : 4 ]) atol = atol
87- @test_throws ArgumentError eig_trunc (A; alg, trunc = (; maxrank = 2 ))
88-
89- alg = TruncatedAlgorithm (LAPACK_Simple (), truncerror (; atol = 0.2 , p = 1 ))
90- D3, V3, ϵ3 = @constinferred eig_trunc (A; alg)
91- @test diagview (D3) ≈ diagview (D)[1 : 2 ]
92- @test ϵ3 ≈ norm (diagview (D)[3 : 4 ]) atol = atol
93- end
94-
95- @testset " eig for Diagonal{$T }" for T in (BLASFloats... , GenericFloats... )
96- rng = StableRNG (123 )
97- m = 54
98- Ad = randn (rng, T, m)
99- A = Diagonal (Ad)
100- atol = sqrt (eps (real (T)))
101-
102- D, V = @constinferred eig_full (A)
103- @test D isa Diagonal{T} && size (D) == size (A)
104- @test V isa Diagonal{T} && size (V) == size (A)
105- @test A * V ≈ V * D
106-
107- D2 = @constinferred eig_vals (A)
108- @test D2 isa AbstractVector{T} && length (D2) == m
109- @test diagview (D) ≈ D2
110-
111- A2 = Diagonal (T[0.9 , 0.3 , 0.1 , 0.01 ])
112- alg = TruncatedAlgorithm (DiagonalAlgorithm (), truncrank (2 ))
113- D2, V2, ϵ2 = @constinferred eig_trunc (A2; alg)
114- @test diagview (D2) ≈ diagview (A2)[1 : 2 ]
115- @test ϵ2 ≈ norm (diagview (A2)[3 : 4 ]) atol = atol
33+ for T in (BLASFloats... , GenericFloats... )
34+ AT = Diagonal{T, Vector{T}}
35+ TestSuite. test_eig (AT, m; test_blocksize = false )
11636end
0 commit comments