@@ -6,7 +6,9 @@ using LinearAlgebra: LinearAlgebra, Diagonal, I
66using MatrixAlgebraKit: TruncatedAlgorithm, diagview
77using AMDGPU
88
9- @testset " eigh_full! for T = $T " for T in (Float32, Float64, ComplexF32, ComplexF64)
9+ BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
10+
11+ @testset " eigh_full! for T = $T " for T in BLASFloats
1012 rng = StableRNG (123 )
1113 m = 54
1214 for alg in (
@@ -32,11 +34,11 @@ using AMDGPU
3234 end
3335end
3436
35- #= @testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
37+ #= @testset "eigh_trunc! for T = $T" for T in BLASFloats
3638 rng = StableRNG(123)
3739 m = 54
38- for alg in (CUSOLVER_QRIteration (),
39- CUSOLVER_DivideAndConquer (),
40+ for alg in (ROCSOLVER_QRIteration (),
41+ ROCSOLVER_DivideAndConquer (),
4042 )
4143 A = ROCArray(randn(rng, T, m, m))
4244 A = A * A'
6466 end
6567end
6668
67- @testset "eigh_trunc! specify truncation algorithm T = $T" for T in
68- (Float32, Float64,
69- ComplexF32,
70- ComplexF64)
69+ @testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats
7170 rng = StableRNG(123)
7271 m = 4
7372 V = qr_compact(ROCArray(randn(rng, T, m, m)))[1]
7473 D = Diagonal([0.9, 0.3, 0.1, 0.01])
7574 A = V * D * V'
7675 A = (A + A') / 2
77- alg = TruncatedAlgorithm(CUSOLVER_QRIteration (), truncrank(2))
76+ alg = TruncatedAlgorithm(ROCSOLVER_QRIteration (), truncrank(2))
7877 D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
7978 @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
8079 @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
8180end=#
81+
82+ @testset " eigh for Diagonal{$T }" for T in BLASFloats
83+ rng = StableRNG (123 )
84+ m = 54
85+ Ad = randn (rng, T, m)
86+ Ad .+ = conj .(Ad)
87+ A = Diagonal (ROCArray (Ad))
88+ atol = sqrt (eps (real (T)))
89+
90+ D, V = @constinferred eigh_full (A)
91+ @test D isa Diagonal{real (T)} && size (D) == size (A)
92+ @test V isa Diagonal{T} && size (V) == size (A)
93+ @test A * V ≈ V * D
94+
95+ D2 = @constinferred eigh_vals (A)
96+ @test D2 isa AbstractVector{real (T)} && length (D2) == m
97+ @test diagview (D) ≈ D2
98+
99+ # TODO partialsortperm
100+ #= A2 = Diagonal(ROCArray(T[0.9, 0.3, 0.1, 0.01]))
101+ alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
102+ D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg)
103+ @test diagview(D2) ≈ diagview(A2)[1:2]
104+ @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol=#
105+ end
0 commit comments