Skip to content

Commit e190dda

Browse files
committed
Updates, more working testsuite
1 parent 6902e53 commit e190dda

File tree

5 files changed

+51
-31
lines changed

5 files changed

+51
-31
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Random = "1"
3737
SafeTestsets = "0.1"
3838
StableRNGs = "1"
3939
Test = "1"
40-
TestExtras = "0.2,0.3"
40+
TestExtras = "0.3.2"
4141
Zygote = "0.7"
4242
julia = "1.10"
4343

@@ -56,4 +56,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
5656
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5757

5858
[targets]
59-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake", "Random"]
59+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Random", "Mooncake"]

test/polar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using MatrixAlgebraKit
22
using Test
33
using StableRNGs
4-
using LinearAlgebra: LinearAlgebra, I, isposdef
4+
using LinearAlgebra: LinearAlgebra, I, isposdef, Diagonal
5+
using CUDA, AMDGPU
56

67
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
78
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
@@ -26,4 +27,3 @@ for T in (BLASFloats..., GenericFloats...)
2627
AT = Diagonal{T, Vector{T}}
2728
TestSuite.test_polar(AT, m; test_pivoted = false, test_blocksize = false)
2829
end
29-

test/projections.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Test
33
using TestExtras
44
using StableRNGs
55
using LinearAlgebra: LinearAlgebra, Diagonal, norm, normalize!
6+
using CUDA, AMDGPU
67

78
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
89
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
@@ -11,15 +12,15 @@ GenericFloats = (Float16, BigFloat, Complex{BigFloat})
1112
using .TestSuite
1213

1314
m = 54
14-
for T in BLASFloats, n in (37, m, 63)
15+
for T in BLASFloats
1516
TestSuite.seed_rng!(123)
16-
TestSuite.test_projections(T, (m, n))
17+
TestSuite.test_projections(T, (m, m))
1718
if CUDA.functional()
18-
TestSuite.test_projections(CuMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
19+
TestSuite.test_projections(CuMatrix{T}, (m, m); test_pivoted = false, test_blocksize = false)
1920
TestSuite.test_projections(Diagonal{T, CuVector{T}}, m; test_pivoted = false, test_blocksize = false)
2021
end
2122
if AMDGPU.functional()
22-
TestSuite.test_projections(ROCMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
23+
TestSuite.test_projections(ROCMatrix{T}, (m, m); test_pivoted = false, test_blocksize = false)
2324
TestSuite.test_projections(Diagonal{T, ROCVector{T}}, m; test_pivoted = false, test_blocksize = false)
2425
end
2526
end

test/testsuite/polar.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
using TestExtras
2+
using LinearAlgebra: isposdef
23

34
function test_polar(T::Type, sz; kwargs...)
45
summary_str = testargs_summary(T, sz)
56
return @testset "polar $summary_str" begin
6-
test_left_polar(T, sz; kwargs...)
7-
test_right_polar(T, sz; kwargs...)
7+
(length(sz) == 1 || sz[1] sz[2]) && test_left_polar(T, sz; kwargs...)
8+
(length(sz) == 1 || sz[2] sz[1]) && test_right_polar(T, sz; kwargs...)
89
end
910
end
1011

@@ -15,7 +16,12 @@ function test_left_polar(
1516
)
1617
summary_str = testargs_summary(T, sz)
1718
return @testset "left_polar! $summary_str" begin
18-
algs = (PolarViaSVD(), PolarNewton())
19+
A = instantiate_matrix(T, sz)
20+
algs = if T <: Diagonal
21+
(PolarNewton(),)
22+
else
23+
(PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)), PolarNewton())
24+
end
1925
@testset "algorithm $alg" for alg in algs
2026
A = instantiate_matrix(T, sz)
2127
Ac = deepcopy(A)
@@ -26,15 +32,15 @@ function test_left_polar(
2632
@test isisometric(W)
2733
@test isposdef(P)
2834

29-
W2, P2 = @constinferred left_polar!(Ac, (W, P), alg)
35+
W2, P2 = @testinferred left_polar!(Ac, (W, P), alg)
3036
@test W2 === W
3137
@test P2 === P
3238
@test W * P A
3339
@test isisometric(W)
3440
@test isposdef(P)
3541

3642
noP = similar(P, (0, 0))
37-
W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, noP), alg)
43+
W2, P2 = @testinferred left_polar!(copy!(Ac, A), (W, noP), alg)
3844
@test P2 === noP
3945
@test W2 === W
4046
@test isisometric(W)
@@ -52,26 +58,31 @@ function test_right_polar(
5258
)
5359
summary_str = testargs_summary(T, sz)
5460
return @testset "right_polar! $summary_str" begin
55-
algs = (PolarViaSVD(), PolarNewton())
61+
A = instantiate_matrix(T, sz)
62+
algs = if T <: Diagonal
63+
(PolarNewton(),)
64+
else
65+
(PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)), PolarNewton())
66+
end
5667
@testset "algorithm $alg" for alg in algs
5768
A = instantiate_matrix(T, sz)
5869
Ac = deepcopy(A)
5970
P, Wᴴ = right_polar(A; alg)
6071
@test eltype(Wᴴ) == eltype(A) && size(Wᴴ) == (size(A, 1), size(A, 2))
61-
@test eltype(P) == eltype(A) && size(P) == (size(A, 1), size(A, 1))
72+
@test eltype(P) == eltype(A) && size(P) == (size(A, 1), size(A, 1))
6273
@test P * Wᴴ A
6374
@test isisometric(Wᴴ; side = :right)
6475
@test isposdef(P)
6576

66-
P2, Wᴴ2 = @constinferred right_polar!(Ac, (P, Wᴴ), alg)
77+
P2, Wᴴ2 = @testinferred right_polar!(Ac, (P, Wᴴ), alg)
6778
@test P2 === P
6879
@test Wᴴ2 === Wᴴ
6980
@test P * Wᴴ A
7081
@test isisometric(Wᴴ; side = :right)
7182
@test isposdef(P)
7283

7384
noP = similar(P, (0, 0))
74-
P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg)
85+
P2, Wᴴ2 = @testinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg)
7586
@test P2 === noP
7687
@test Wᴴ2 === Wᴴ
7788
@test isisometric(Wᴴ; side = :right)

test/testsuite/projections.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using TestExtras
2+
using LinearAlgebra: Diagonal, normalize!
23

34
function test_projections(T::Type, sz; kwargs...)
45
summary_str = testargs_summary(T, sz)
@@ -16,7 +17,7 @@ function test_project_antihermitian(
1617
)
1718
summary_str = testargs_summary(T, sz)
1819
return @testset "project_antihermitian! $summary_str" begin
19-
noisefactor = eps(real(T))^(3 / 4)
20+
noisefactor = eps(real(eltype(T)))^(3 / 4)
2021
algs = (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
2122
@testset "algorithm $alg" for alg in algs
2223
A = instantiate_matrix(T, sz)
@@ -30,7 +31,8 @@ function test_project_antihermitian(
3031
@test A == Ac
3132
Ba_approx = Ba + noisefactor * Ah
3233
@test !isantihermitian(Ba_approx)
33-
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
34+
# this is never anti-hermitian for real Diagonal: |A - A'| == 0
35+
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0
3436

3537
copy!(Ac, A)
3638
Ba = project_antihermitian!(Ac, alg)
@@ -40,7 +42,7 @@ function test_project_antihermitian(
4042
end
4143

4244
# test approximate error calculation
43-
A = normalize!(randn(rng, T, m, m))
45+
A = normalize!(randn(rng, eltype(T), size(A)...))
4446
Ah = project_hermitian(A)
4547
Aa = project_antihermitian(A)
4648

@@ -63,7 +65,7 @@ function test_project_hermitian(
6365
)
6466
summary_str = testargs_summary(T, sz)
6567
return @testset "project_hermitian! $summary_str" begin
66-
noisefactor = eps(real(T))^(3 / 4)
68+
noisefactor = eps(real(eltype(T)))^(3 / 4)
6769
algs = (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
6870
@testset "algorithm $alg" for alg in algs
6971
A = instantiate_matrix(T, sz)
@@ -76,7 +78,8 @@ function test_project_hermitian(
7678
@test Bh Ah
7779
@test A == Ac
7880
Bh_approx = Bh + noisefactor * Aa
79-
@test !ishermitian(Bh_approx)
81+
# this is still hermitian for real Diagonal: |A - A'| == 0
82+
@test !ishermitian(Bh_approx) || norm(Aa) == 0
8083
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
8184

8285
Bh = project_hermitian!(Ac, alg)
@@ -86,7 +89,7 @@ function test_project_hermitian(
8689
end
8790

8891
# test approximate error calculation
89-
A = normalize!(randn(rng, T, m, m))
92+
A = normalize!(randn(rng, eltype(T), size(A)...))
9093
Ah = project_hermitian(A)
9194
Aa = project_antihermitian(A)
9295

@@ -109,24 +112,29 @@ function test_project_isometric(
109112
)
110113
summary_str = testargs_summary(T, sz)
111114
return @testset "project_isometric! $summary_str" begin
112-
algs = (PolarViaSVD(), PolarNewton())
115+
A = instantiate_matrix(T, sz)
116+
algs = if T <: Diagonal
117+
(PolarNewton(),)
118+
else
119+
(PolarViaSVD(MatrixAlgebraKit.default_svd_algorithm(A)), PolarNewton())
120+
end
113121
@testset "algorithm $alg" for alg in algs
114-
A = instantiate_matrix(T, sz)
115-
Ac = deepcopy(A)
116-
k = min(size(A)...)
117-
W = project_isometric(A, alg)
122+
A = instantiate_matrix(T, sz)
123+
Ac = deepcopy(A)
124+
k = min(size(A)...)
125+
W = project_isometric(A, alg)
118126
@test isisometric(W)
119-
W2 = project_isometric(W, alg)
127+
W2 = project_isometric(W, alg)
120128
@test W2 W # stability of the projection
121129
@test W * (W' * A) A
122130

123-
W2 = @constinferred project_isometric!(Ac, W, alg)
131+
W2 = @testinferred project_isometric!(Ac, W, alg)
124132
@test W2 === W
125133
@test isisometric(W)
126134

127135
# test that W is closer to A then any other isometry
128136
for k in 1:10
129-
δA = randn(rng, T, m, n)
137+
δA = randn(rng, eltype(T), size(A)...)
130138
W = project_isometric(A, alg)
131139
W2 = project_isometric(A + δA / 100, alg)
132140
@test norm(A - W2) > norm(A - W)

0 commit comments

Comments
 (0)