Skip to content

Commit e72dca4

Browse files
authored
[WIP] Implement DiagonalAlgorithms for the GPU (#106)
* Implement DiagonalAlgorithm for GPU * Comments and format
1 parent ba2c9ef commit e72dca4

File tree

10 files changed

+527
-194
lines changed

10 files changed

+527
-194
lines changed

src/implementations/svd.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,11 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
218218
check_input(svd_full!, A, USVᴴ, alg)
219219
Ad = diagview(A)
220220
U, S, Vᴴ = USVᴴ
221+
if isempty(Ad)
222+
one!(U)
223+
one!(Vᴴ)
224+
return USVᴴ
225+
end
221226
p = sortperm(Ad; by = abs, rev = true)
222227
zero!(U)
223228
zero!(Vᴴ)

test/amd/eigh.jl

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ using LinearAlgebra: LinearAlgebra, Diagonal, I
66
using MatrixAlgebraKit: TruncatedAlgorithm, diagview
77
using AMDGPU
88

9-
@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
9+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
10+
11+
@testset "eigh_full! for T = $T" for T in BLASFloats
1012
rng = StableRNG(123)
1113
m = 54
1214
for alg in (
@@ -32,11 +34,11 @@ using AMDGPU
3234
end
3335
end
3436

35-
#=@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
37+
#=@testset "eigh_trunc! for T = $T" for T in BLASFloats
3638
rng = StableRNG(123)
3739
m = 54
38-
for alg in (CUSOLVER_QRIteration(),
39-
CUSOLVER_DivideAndConquer(),
40+
for alg in (ROCSOLVER_QRIteration(),
41+
ROCSOLVER_DivideAndConquer(),
4042
)
4143
A = ROCArray(randn(rng, T, m, m))
4244
A = A * A'
@@ -64,18 +66,40 @@ end
6466
end
6567
end
6668
67-
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
68-
(Float32, Float64,
69-
ComplexF32,
70-
ComplexF64)
69+
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats
7170
rng = StableRNG(123)
7271
m = 4
7372
V = qr_compact(ROCArray(randn(rng, T, m, m)))[1]
7473
D = Diagonal([0.9, 0.3, 0.1, 0.01])
7574
A = V * D * V'
7675
A = (A + A') / 2
77-
alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2))
76+
alg = TruncatedAlgorithm(ROCSOLVER_QRIteration(), truncrank(2))
7877
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
7978
@test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
8079
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
8180
end=#
81+
82+
@testset "eigh for Diagonal{$T}" for T in BLASFloats
83+
rng = StableRNG(123)
84+
m = 54
85+
Ad = randn(rng, T, m)
86+
Ad .+= conj.(Ad)
87+
A = Diagonal(ROCArray(Ad))
88+
atol = sqrt(eps(real(T)))
89+
90+
D, V = @constinferred eigh_full(A)
91+
@test D isa Diagonal{real(T)} && size(D) == size(A)
92+
@test V isa Diagonal{T} && size(V) == size(A)
93+
@test A * V ≈ V * D
94+
95+
D2 = @constinferred eigh_vals(A)
96+
@test D2 isa AbstractVector{real(T)} && length(D2) == m
97+
@test diagview(D) ≈ D2
98+
99+
# TODO partialsortperm
100+
#=A2 = Diagonal(ROCArray(T[0.9, 0.3, 0.1, 0.01]))
101+
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
102+
D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg)
103+
@test diagview(D2) ≈ diagview(A2)[1:2]
104+
@test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol=#
105+
end

test/amd/lq.jl

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ using Test
44
using TestExtras
55
using StableRNGs
66
using AMDGPU
7+
using LinearAlgebra
78

89
include(joinpath("..", "utilities.jl"))
910

10-
@testset "lq_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
11+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
12+
13+
@testset "lq_compact! for T = $T" for T in BLASFloats
1114
rng = StableRNG(123)
1215
m = 54
1316
for n in (37, m, 63)
@@ -65,7 +68,7 @@ include(joinpath("..", "utilities.jl"))
6568
end
6669
end
6770

68-
@testset "lq_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
71+
@testset "lq_full! for T = $T" for T in BLASFloats
6972
rng = StableRNG(123)
7073
m = 54
7174
for n in (37, m, 63)
@@ -115,3 +118,46 @@ end
115118
@test_throws ArgumentError lq_full!(copy!(Ac, A), (L, Q); blocksize = 8)
116119
end
117120
end
121+
122+
@testset "lq_compact, lq_full and lq_null for Diagonal{$T}" for T in BLASFloats
123+
rng = StableRNG(123)
124+
atol = eps(real(T))^(3 / 4)
125+
for m in (54, 0)
126+
Ad = ROCArray(randn(rng, T, m))
127+
A = Diagonal(Ad)
128+
129+
# compact
130+
L, Q = @constinferred lq_compact(A)
131+
@test Q isa Diagonal{T} && size(Q) == (m, m)
132+
@test L isa Diagonal{T} && size(L) == (m, m)
133+
@test L * Q ≈ A
134+
@test isunitary(Q)
135+
136+
# compact and positive
137+
Lp, Qp = @constinferred lq_compact(A; positive = true)
138+
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
139+
@test Lp isa Diagonal{T} && size(Lp) == (m, m)
140+
@test Lp * Qp ≈ A
141+
@test isunitary(Qp)
142+
@test all(isposdef.(diagview(Lp)))
143+
144+
# full
145+
L, Q = @constinferred lq_full(A)
146+
@test Q isa Diagonal{T} && size(Q) == (m, m)
147+
@test L isa Diagonal{T} && size(L) == (m, m)
148+
@test L * Q ≈ A
149+
@test isunitary(Q)
150+
151+
# full and positive
152+
Lp, Qp = @constinferred lq_full(A; positive = true)
153+
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
154+
@test Lp isa Diagonal{T} && size(Lp) == (m, m)
155+
@test Lp * Qp ≈ A
156+
@test isunitary(Qp)
157+
@test all(isposdef.(diagview(Lp)))
158+
159+
# null
160+
N = @constinferred lq_null(A)
161+
@test N isa AbstractMatrix{T} && size(N) == (0, m)
162+
end
163+
end

test/amd/qr.jl

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ using Test
44
using TestExtras
55
using StableRNGs
66
using AMDGPU
7+
using LinearAlgebra
78

89
include(joinpath("..", "utilities.jl"))
910

10-
eltypes = (Float32, Float64, ComplexF32, ComplexF64)
11+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
1112

12-
@testset "qr_compact! and qr_null! for T = $T" for T in eltypes
13+
@testset "qr_compact! and qr_null! for T = $T" for T in BLASFloats
1314
rng = StableRNG(123)
1415
m = 54
1516
for n in (37, m, 63)
@@ -68,7 +69,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64)
6869
end
6970
end
7071

71-
@testset "qr_full! for T = $T" for T in eltypes
72+
@testset "qr_full! for T = $T" for T in BLASFloats
7273
rng = StableRNG(123)
7374
m = 63
7475
for n in (37, m, 63)
@@ -121,3 +122,46 @@ end
121122
@test_throws ArgumentError qr_full!(copy!(Ac, A), (Q, R); blocksize = 8)
122123
end
123124
end
125+
126+
@testset "qr_compact, qr_full and qr_null for Diagonal{$T}" for T in BLASFloats
127+
rng = StableRNG(123)
128+
atol = eps(real(T))^(3 / 4)
129+
for m in (54, 0)
130+
Ad = ROCArray(randn(rng, T, m))
131+
A = Diagonal(Ad)
132+
133+
# compact
134+
Q, R = @constinferred qr_compact(A)
135+
@test Q isa Diagonal{T} && size(Q) == (m, m)
136+
@test R isa Diagonal{T} && size(R) == (m, m)
137+
@test Q * R ≈ A
138+
@test isunitary(Q)
139+
140+
# compact and positive
141+
Qp, Rp = @constinferred qr_compact(A; positive = true)
142+
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
143+
@test Rp isa Diagonal{T} && size(Rp) == (m, m)
144+
@test Qp * Rp ≈ A
145+
@test isunitary(Qp)
146+
@test all(isposdef.(diagview(Rp)))
147+
148+
# full
149+
Q, R = @constinferred qr_full(A)
150+
@test Q isa Diagonal{T} && size(Q) == (m, m)
151+
@test R isa Diagonal{T} && size(R) == (m, m)
152+
@test Q * R ≈ A
153+
@test isunitary(Q)
154+
155+
# full and positive
156+
Qp, Rp = @constinferred qr_full(A; positive = true)
157+
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
158+
@test Rp isa Diagonal{T} && size(Rp) == (m, m)
159+
@test Qp * Rp ≈ A
160+
@test isunitary(Qp)
161+
@test all(isposdef.(diagview(Rp)))
162+
163+
# null
164+
N = @constinferred qr_null(A)
165+
@test N isa AbstractMatrix{T} && size(N) == (m, 0)
166+
end
167+
end

0 commit comments

Comments
 (0)