Skip to content

Commit d154bae

Browse files
committed
dont call project_isometric with invalid alg-input types
1 parent 6bf4854 commit d154bae

File tree

2 files changed

+75
-32
lines changed

2 files changed

+75
-32
lines changed

test/amd/projections.jl

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,46 @@ end
5757
svdalgs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
5858
algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
5959
@testset "algorithm $alg" for alg in algs
60-
for A in (ROCArray(randn(rng, T, m, n)), Diagonal(ROCArray(randn(rng, T, m))))
60+
A = ROCArray(randn(rng, T, m, n))
61+
W = project_isometric(A, alg)
62+
@test isisometric(W)
63+
W2 = project_isometric(W, alg)
64+
@test W2 W # stability of the projection
65+
@test W * (W' * A) A
66+
67+
Ac = similar(A)
68+
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
69+
@test W2 === W
70+
@test isisometric(W)
71+
72+
# test that W is closer to A then any other isometry
73+
for k in 1:10
74+
δA = ROCArray(randn(rng, T, size(A)...))
6175
W = project_isometric(A, alg)
62-
@test isisometric(W)
63-
W2 = project_isometric(W, alg)
64-
@test W2 W # stability of the projection
65-
@test W * (W' * A) A
76+
W2 = project_isometric(A + δA / 100, alg)
77+
@test norm(A - W2) >= norm(A - W)
78+
end
79+
end
80+
81+
m == n && @testset "DiagonalAlgorithm" begin
82+
A = Diagonal(ROCArray(randn(rng, T, m)))
83+
W = project_isometric(A, alg)
84+
@test isisometric(W)
85+
W2 = project_isometric(W, alg)
86+
@test W2 W # stability of the projection
87+
@test W * (W' * A) A
6688

67-
Ac = similar(A)
68-
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
69-
@test W2 === W
70-
@test isisometric(W)
89+
Ac = similar(A)
90+
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
91+
@test W2 === W
92+
@test isisometric(W)
7193

72-
# test that W is closer to A then any other isometry
73-
for k in 1:10
74-
δA = ROCArray(randn(rng, T, size(A)...))
75-
W = project_isometric(A, alg)
76-
W2 = project_isometric(A + δA / 100, alg)
77-
@test norm(A - W2) > norm(A - W)
78-
end
94+
# test that W is closer to A then any other isometry
95+
for k in 1:10
96+
δA = Diagonal(ROCArray(randn(rng, T, m)))
97+
W = project_isometric(A, alg)
98+
W2 = project_isometric(A + δA / 100, alg)
99+
@test norm(A - W2) >= norm(A - W)
79100
end
80101
end
81102
end

test/cuda/projections.jl

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,47 @@ end
5757
svdalgs = (CUSOLVER_SVDPolar(), CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
5858
algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
5959
@testset "algorithm $alg" for alg in algs
60-
for A in (CuArray(randn(rng, T, m, n)), Diagonal(CuArray(randn(rng, T, m))))
60+
A = CuArray(randn(rng, T, m, n))
61+
W = project_isometric(A, alg)
62+
@test isisometric(W)
63+
W2 = project_isometric(W, alg)
64+
@test W2 W # stability of the projection
65+
@test W * (W' * A) A
66+
67+
Ac = similar(A)
68+
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
69+
@test W2 === W
70+
@test isisometric(W)
71+
72+
# test that W is closer to A then any other isometry
73+
for k in 1:10
74+
δA = CuArray(randn(rng, T, size(A)...))
6175
W = project_isometric(A, alg)
62-
@test isisometric(W)
63-
W2 = project_isometric(W, alg)
64-
@test W2 W # stability of the projection
65-
@test W * (W' * A) A
76+
W2 = project_isometric(A + δA / 100, alg)
77+
@test norm(A - W2) >= norm(A - W)
78+
end
79+
end
80+
81+
m == n && @testset "DiagonalAlgorithm" begin
82+
A = Diagonal(CuArray(randn(rng, T, m)))
83+
alg = PolarViaSVD(DiagonalAlgorithm())
84+
W = project_isometric(A, alg)
85+
@test isisometric(W)
86+
W2 = project_isometric(W, alg)
87+
@test W2 W # stability of the projection
88+
@test W * (W' * A) A
6689

67-
Ac = similar(A)
68-
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
69-
@test W2 === W
70-
@test isisometric(W)
90+
Ac = similar(A)
91+
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
92+
@test W2 === W
93+
@test isisometric(W)
7194

72-
# test that W is closer to A then any other isometry
73-
for k in 1:10
74-
δA = CuArray(randn(rng, T, size(A)...))
75-
W = project_isometric(A, alg)
76-
W2 = project_isometric(A + δA / 100, alg)
77-
@test norm(A - W2) > norm(A - W)
78-
end
95+
# test that W is closer to A then any other isometry
96+
for k in 1:10
97+
δA = Diagonal(CuArray(randn(rng, T, m)))
98+
W = project_isometric(A, alg)
99+
W2 = project_isometric(A + δA / 100, alg)
100+
@test norm(A - W2) >= norm(A - W)
79101
end
80102
end
81103
end

0 commit comments

Comments
 (0)