Skip to content

Commit 1b4399f

Browse files
committed
Add tests for SVD
1 parent 8c574ec commit 1b4399f

File tree

1 file changed

+92
-84
lines changed

1 file changed

+92
-84
lines changed

test/cuda/svd.jl

Lines changed: 92 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,34 @@ include(joinpath("..", "utilities.jl"))
1515
k = min(m, n)
1616
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
1717
@testset "algorithm $alg" for alg in algs
18-
minmn = min(m, n)
19-
A = CuArray(randn(rng, T, m, n))
18+
As = m == n ? (CuArray(randn(rng, T, m, n)), Diagonal(CuArray(randn(rng, T, m)))) : (CuArray(randn(rng, T, m, n)),)
19+
for A in As
20+
minmn = min(m, n)
21+
U, S, Vᴴ = svd_compact(A; alg)
22+
@test U isa CuMatrix{T} && size(U) == (m, minmn)
23+
@test S isa Diagonal{real(T), <:CuVector} && size(S) == (minmn, minmn)
24+
@test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (minmn, n)
25+
@test U * S * Vᴴ A
26+
@test isapproxone(U' * U)
27+
@test isapproxone(Vᴴ * Vᴴ')
28+
@test isposdef(S)
2029

21-
U, S, Vᴴ = svd_compact(A; alg)
22-
@test U isa CuMatrix{T} && size(U) == (m, minmn)
23-
@test S isa Diagonal{real(T), <:CuVector} && size(S) == (minmn, minmn)
24-
@test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (minmn, n)
25-
@test U * S * Vᴴ A
26-
@test isapproxone(U' * U)
27-
@test isapproxone(Vᴴ * Vᴴ')
28-
@test isposdef(S)
30+
Ac = similar(A)
31+
U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg)
32+
@test U2 === U
33+
@test S2 === S
34+
@test V2ᴴ === Vᴴ
35+
@test U * S * Vᴴ A
36+
@test isapproxone(U' * U)
37+
@test isapproxone(Vᴴ * Vᴴ')
38+
@test isposdef(S)
2939

30-
Ac = similar(A)
31-
U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg)
32-
@test U2 === U
33-
@test S2 === S
34-
@test V2ᴴ === Vᴴ
35-
@test U * S * Vᴴ A
36-
@test isapproxone(U' * U)
37-
@test isapproxone(Vᴴ * Vᴴ')
38-
@test isposdef(S)
39-
40-
Sd = svd_vals(A, alg)
41-
@test CuArray(diagview(S)) Sd
42-
# CuArray is necessary because norm of CuArray view with non-unit step is broken
43-
if alg isa CUSOLVER_QRIteration
44-
@test_warn "invalid keyword arguments for GPU_QRIteration" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad"))
40+
Sd = svd_vals(A, alg)
41+
@test CuArray(diagview(S)) Sd
42+
# CuArray is necessary because norm of CuArray view with non-unit step is broken
43+
if alg isa CUSOLVER_QRIteration
44+
@test_warn "invalid keyword arguments for GPU_QRIteration" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad"))
45+
end
4546
end
4647
end
4748
end
@@ -53,56 +54,61 @@ end
5354
@testset "size ($m, $n)" for n in (37, m, 63)
5455
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
5556
@testset "algorithm $alg" for alg in algs
56-
A = CuArray(randn(rng, T, m, n))
57-
U, S, Vᴴ = svd_full(A; alg)
58-
@test U isa CuMatrix{T} && size(U) == (m, m)
59-
@test S isa CuMatrix{real(T)} && size(S) == (m, n)
60-
@test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (n, n)
61-
@test U * S * Vᴴ A
62-
@test isapproxone(U' * U)
63-
@test isapproxone(U * U')
64-
@test isapproxone(Vᴴ * Vᴴ')
65-
@test isapproxone(Vᴴ' * Vᴴ)
66-
@test all(isposdef, diagview(S))
57+
As = m == n ? (CuArray(randn(rng, T, m, n)), Diagonal(CuArray(randn(rng, T, m)))) : (CuArray(randn(rng, T, m, n)),)
58+
for A in As
59+
minmn = min(m, n)
60+
U, S, Vᴴ = svd_full(A; alg)
61+
@test U isa CuMatrix{T} && size(U) == (m, m)
62+
@test S isa CuMatrix{real(T)} && size(S) == (m, n)
63+
@test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (n, n)
64+
@test U * S * Vᴴ A
65+
@test isapproxone(U' * U)
66+
@test isapproxone(U * U')
67+
@test isapproxone(Vᴴ * Vᴴ')
68+
@test isapproxone(Vᴴ' * Vᴴ)
69+
@test all(isposdef, diagview(S))
6770

68-
Ac = similar(A)
69-
U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg)
70-
@test U2 === U
71-
@test S2 === S
72-
@test V2ᴴ === Vᴴ
73-
@test U * S * Vᴴ A
74-
@test isapproxone(U' * U)
75-
@test isapproxone(U * U')
76-
@test isapproxone(Vᴴ * Vᴴ')
77-
@test isapproxone(Vᴴ' * Vᴴ)
78-
@test all(isposdef, diagview(S))
71+
Ac = similar(A)
72+
U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg)
73+
@test U2 === U
74+
@test S2 === S
75+
@test V2ᴴ === Vᴴ
76+
@test U * S * Vᴴ A
77+
@test isapproxone(U' * U)
78+
@test isapproxone(U * U')
79+
@test isapproxone(Vᴴ * Vᴴ')
80+
@test isapproxone(Vᴴ' * Vᴴ)
81+
@test all(isposdef, diagview(S))
7982

80-
minmn = min(m, n)
81-
Sc = similar(A, real(T), minmn)
82-
Sc2 = svd_vals!(copy!(Ac, A), Sc, alg)
83-
@test Sc === Sc2
84-
@test CuArray(diagview(S)) Sc
85-
# CuArray is necessary because norm of CuArray view with non-unit step is broken
86-
if alg isa CUSOLVER_QRIteration
87-
@test_warn "invalid keyword arguments for GPU_QRIteration" svd_full!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad"))
88-
@test_warn "invalid keyword arguments for GPU_QRIteration" svd_vals!(copy!(Ac, A), Sc, CUSOLVER_QRIteration(; bad = "bad"))
83+
Sc = similar(A, real(T), minmn)
84+
Sc2 = svd_vals!(copy!(Ac, A), Sc, alg)
85+
@test Sc === Sc2
86+
@test CuArray(diagview(S)) Sc
87+
# CuArray is necessary because norm of CuArray view with non-unit step is broken
88+
if alg isa CUSOLVER_QRIteration
89+
@test_warn "invalid keyword arguments for GPU_QRIteration" svd_full!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad"))
90+
if !isa(A, Diagonal)
91+
@test_warn "invalid keyword arguments for GPU_QRIteration" svd_vals!(copy!(Ac, A), Sc, CUSOLVER_QRIteration(; bad = "bad"))
92+
end
93+
end
8994
end
9095
end
9196
end
9297
@testset "size (0, 0)" begin
9398
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi())
9499
@testset "algorithm $alg" for alg in algs
95-
A = CuArray(randn(rng, T, 0, 0))
96-
U, S, Vᴴ = svd_full(A; alg)
97-
@test U isa CuMatrix{T} && size(U) == (0, 0)
98-
@test S isa CuMatrix{real(T)} && size(S) == (0, 0)
99-
@test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (0, 0)
100-
@test U * S * Vᴴ A
101-
@test isapproxone(U' * U)
102-
@test isapproxone(U * U')
103-
@test isapproxone(Vᴴ * Vᴴ')
104-
@test isapproxone(Vᴴ' * Vᴴ)
105-
@test all(isposdef, diagview(S))
100+
for A in (CuArray(randn(rng, T, 0, 0)), Diagonal(CuArray(randn(rng, T, 0))))
101+
U, S, Vᴴ = svd_full(A; alg)
102+
@test U isa CuMatrix{T} && size(U) == (0, 0)
103+
@test S isa CuMatrix{real(T)} && size(S) == (0, 0)
104+
@test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (0, 0)
105+
@test U * S * Vᴴ A
106+
@test isapproxone(U' * U)
107+
@test isapproxone(U * U')
108+
@test isapproxone(Vᴴ * Vᴴ')
109+
@test isapproxone(Vᴴ' * Vᴴ)
110+
@test all(isposdef, diagview(S))
111+
end
106112
end
107113
end
108114
end
@@ -115,26 +121,28 @@ end
115121
p = min(m, n) - k - 1
116122
algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi(), CUSOLVER_Randomized(; k = k, p = p, niters = 100))
117123
@testset "algorithm $alg" for alg in algs
118-
hA = randn(rng, T, m, n)
119-
S₀ = svd_vals(hA)
120-
A = CuArray(hA)
121-
minmn = min(m, n)
122-
r = k
124+
hAs = m == n ? (randn(rng, T, m, n), Diagonal(randn(rng, T, m))) : (randn(rng, T, m, n),)
125+
for hA in hAs
126+
S₀ = svd_vals(hA)
127+
A = CuArray(hA)
128+
minmn = min(m, n)
129+
r = k
123130

124-
U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r))
125-
@test length(S1.diag) == r
126-
@test opnorm(A - U1 * S1 * V1ᴴ) S₀[r + 1]
127-
@test norm(A - U1 * S1 * V1ᴴ) ϵ1
131+
U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r))
132+
@test length(S1.diag) == r
133+
@test opnorm(A - U1 * S1 * V1ᴴ) S₀[r + 1]
134+
@test norm(A - U1 * S1 * V1ᴴ) ϵ1
128135

129-
if !(alg isa CUSOLVER_Randomized)
130-
s = 1 + sqrt(eps(real(T)))
131-
trunc2 = trunctol(; atol = s * S₀[r + 1])
136+
if !(alg isa CUSOLVER_Randomized)
137+
s = 1 + sqrt(eps(real(T)))
138+
trunc2 = trunctol(; atol = s * S₀[r + 1])
132139

133-
U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1]))
134-
@test length(S2.diag) == r
135-
@test U1 U2
136-
@test parent(S1) parent(S2)
137-
@test V1ᴴ V2ᴴ
140+
U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1]))
141+
@test length(S2.diag) == r
142+
@test U1 U2
143+
@test parent(S1) parent(S2)
144+
@test V1ᴴ V2ᴴ
145+
end
138146
end
139147
end
140148
end

0 commit comments

Comments
 (0)