|
57 | 57 | svdalgs = (CUSOLVER_SVDPolar(), CUSOLVER_QRIteration(), CUSOLVER_Jacobi()) |
58 | 58 | algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO |
59 | 59 | @testset "algorithm $alg" for alg in algs |
60 | | - for A in (CuArray(randn(rng, T, m, n)), Diagonal(CuArray(randn(rng, T, m)))) |
| 60 | + A = CuArray(randn(rng, T, m, n)) |
| 61 | + W = project_isometric(A, alg) |
| 62 | + @test isisometric(W) |
| 63 | + W2 = project_isometric(W, alg) |
| 64 | + @test W2 ≈ W # stability of the projection |
| 65 | + @test W * (W' * A) ≈ A |
| 66 | + |
| 67 | + Ac = similar(A) |
| 68 | + W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg) |
| 69 | + @test W2 === W |
| 70 | + @test isisometric(W) |
| 71 | + |
| 72 | + # test that W is closer to A then any other isometry |
| 73 | + for k in 1:10 |
| 74 | + δA = CuArray(randn(rng, T, size(A)...)) |
61 | 75 | W = project_isometric(A, alg) |
62 | | - @test isisometric(W) |
63 | | - W2 = project_isometric(W, alg) |
64 | | - @test W2 ≈ W # stability of the projection |
65 | | - @test W * (W' * A) ≈ A |
| 76 | + W2 = project_isometric(A + δA / 100, alg) |
| 77 | + @test norm(A - W2) >= norm(A - W) |
| 78 | + end |
| 79 | + end |
| 80 | + |
| 81 | + m == n && @testset "DiagonalAlgorithm" begin |
| 82 | + A = Diagonal(CuArray(randn(rng, T, m))) |
| 83 | + alg = PolarViaSVD(DiagonalAlgorithm()) |
| 84 | + W = project_isometric(A, alg) |
| 85 | + @test isisometric(W) |
| 86 | + W2 = project_isometric(W, alg) |
| 87 | + @test W2 ≈ W # stability of the projection |
| 88 | + @test W * (W' * A) ≈ A |
66 | 89 |
|
67 | | - Ac = similar(A) |
68 | | - W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg) |
69 | | - @test W2 === W |
70 | | - @test isisometric(W) |
| 90 | + Ac = similar(A) |
| 91 | + W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg) |
| 92 | + @test W2 === W |
| 93 | + @test isisometric(W) |
71 | 94 |
|
72 | | - # test that W is closer to A then any other isometry |
73 | | - for k in 1:10 |
74 | | - δA = CuArray(randn(rng, T, size(A)...)) |
75 | | - W = project_isometric(A, alg) |
76 | | - W2 = project_isometric(A + δA / 100, alg) |
77 | | - @test norm(A - W2) > norm(A - W) |
78 | | - end |
| 95 | + # test that W is closer to A then any other isometry |
| 96 | + for k in 1:10 |
| 97 | + δA = Diagonal(CuArray(randn(rng, T, m))) |
| 98 | + W = project_isometric(A, alg) |
| 99 | + W2 = project_isometric(A + δA / 100, alg) |
| 100 | + @test norm(A - W2) >= norm(A - W) |
79 | 101 | end |
80 | 102 | end |
81 | 103 | end |
|
0 commit comments