Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
check_input(svd_full!, A, USVᴴ, alg)
Ad = diagview(A)
U, S, Vᴴ = USVᴴ
if isempty(Ad)
one!(U)
one!(Vᴴ)
return USVᴴ
end
p = sortperm(Ad; by = abs, rev = true)
zero!(U)
zero!(Vᴴ)
Expand Down
42 changes: 33 additions & 9 deletions test/amd/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ using LinearAlgebra: LinearAlgebra, Diagonal, I
using MatrixAlgebraKit: TruncatedAlgorithm, diagview
using AMDGPU

@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)

@testset "eigh_full! for T = $T" for T in BLASFloats
rng = StableRNG(123)
m = 54
for alg in (
Expand All @@ -32,11 +34,11 @@ using AMDGPU
end
end

#=@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
#=@testset "eigh_trunc! for T = $T" for T in BLASFloats
rng = StableRNG(123)
m = 54
for alg in (CUSOLVER_QRIteration(),
CUSOLVER_DivideAndConquer(),
for alg in (ROCSOLVER_QRIteration(),
ROCSOLVER_DivideAndConquer(),
)
A = ROCArray(randn(rng, T, m, m))
A = A * A'
Expand Down Expand Up @@ -64,18 +66,40 @@ end
end
end

@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
(Float32, Float64,
ComplexF32,
ComplexF64)
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats
rng = StableRNG(123)
m = 4
V = qr_compact(ROCArray(randn(rng, T, m, m)))[1]
D = Diagonal([0.9, 0.3, 0.1, 0.01])
A = V * D * V'
A = (A + A') / 2
alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2))
alg = TruncatedAlgorithm(ROCSOLVER_QRIteration(), truncrank(2))
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
@test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
end=#

@testset "eigh for Diagonal{$T}" for T in BLASFloats
rng = StableRNG(123)
m = 54
Ad = randn(rng, T, m)
Ad .+= conj.(Ad)
A = Diagonal(ROCArray(Ad))
atol = sqrt(eps(real(T)))

D, V = @constinferred eigh_full(A)
@test D isa Diagonal{real(T)} && size(D) == size(A)
@test V isa Diagonal{T} && size(V) == size(A)
@test A * V ≈ V * D

D2 = @constinferred eigh_vals(A)
@test D2 isa AbstractVector{real(T)} && length(D2) == m
@test diagview(D) ≈ D2

# TODO partialsortperm
#=A2 = Diagonal(ROCArray(T[0.9, 0.3, 0.1, 0.01]))
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg)
@test diagview(D2) ≈ diagview(A2)[1:2]
@test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol=#
end
50 changes: 48 additions & 2 deletions test/amd/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ using Test
using TestExtras
using StableRNGs
using AMDGPU
using LinearAlgebra

include(joinpath("..", "utilities.jl"))

@testset "lq_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)

@testset "lq_compact! for T = $T" for T in BLASFloats
rng = StableRNG(123)
m = 54
for n in (37, m, 63)
Expand Down Expand Up @@ -65,7 +68,7 @@ include(joinpath("..", "utilities.jl"))
end
end

@testset "lq_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset "lq_full! for T = $T" for T in BLASFloats
rng = StableRNG(123)
m = 54
for n in (37, m, 63)
Expand Down Expand Up @@ -115,3 +118,46 @@ end
@test_throws ArgumentError lq_full!(copy!(Ac, A), (L, Q); blocksize = 8)
end
end

@testset "lq_compact, lq_full and lq_null for Diagonal{$T}" for T in BLASFloats
rng = StableRNG(123)
atol = eps(real(T))^(3 / 4)
for m in (54, 0)
Ad = ROCArray(randn(rng, T, m))
A = Diagonal(Ad)

# compact
L, Q = @constinferred lq_compact(A)
@test Q isa Diagonal{T} && size(Q) == (m, m)
@test L isa Diagonal{T} && size(L) == (m, m)
@test L * Q ≈ A
@test isunitary(Q)

# compact and positive
Lp, Qp = @constinferred lq_compact(A; positive = true)
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
@test Lp isa Diagonal{T} && size(Lp) == (m, m)
@test Lp * Qp ≈ A
@test isunitary(Qp)
@test all(isposdef.(diagview(Lp)))

# full
L, Q = @constinferred lq_full(A)
@test Q isa Diagonal{T} && size(Q) == (m, m)
@test L isa Diagonal{T} && size(L) == (m, m)
@test L * Q ≈ A
@test isunitary(Q)

# full and positive
Lp, Qp = @constinferred lq_full(A; positive = true)
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
@test Lp isa Diagonal{T} && size(Lp) == (m, m)
@test Lp * Qp ≈ A
@test isunitary(Qp)
@test all(isposdef.(diagview(Lp)))

# null
N = @constinferred lq_null(A)
@test N isa AbstractMatrix{T} && size(N) == (0, m)
end
end
50 changes: 47 additions & 3 deletions test/amd/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ using Test
using TestExtras
using StableRNGs
using AMDGPU
using LinearAlgebra

include(joinpath("..", "utilities.jl"))

eltypes = (Float32, Float64, ComplexF32, ComplexF64)
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)

@testset "qr_compact! and qr_null! for T = $T" for T in eltypes
@testset "qr_compact! and qr_null! for T = $T" for T in BLASFloats
rng = StableRNG(123)
m = 54
for n in (37, m, 63)
Expand Down Expand Up @@ -68,7 +69,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64)
end
end

@testset "qr_full! for T = $T" for T in eltypes
@testset "qr_full! for T = $T" for T in BLASFloats
rng = StableRNG(123)
m = 63
for n in (37, m, 63)
Expand Down Expand Up @@ -121,3 +122,46 @@ end
@test_throws ArgumentError qr_full!(copy!(Ac, A), (Q, R); blocksize = 8)
end
end

@testset "qr_compact, qr_full and qr_null for Diagonal{$T}" for T in BLASFloats
rng = StableRNG(123)
atol = eps(real(T))^(3 / 4)
for m in (54, 0)
Ad = ROCArray(randn(rng, T, m))
A = Diagonal(Ad)

# compact
Q, R = @constinferred qr_compact(A)
@test Q isa Diagonal{T} && size(Q) == (m, m)
@test R isa Diagonal{T} && size(R) == (m, m)
@test Q * R ≈ A
@test isunitary(Q)

# compact and positive
Qp, Rp = @constinferred qr_compact(A; positive = true)
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
@test Rp isa Diagonal{T} && size(Rp) == (m, m)
@test Qp * Rp ≈ A
@test isunitary(Qp)
@test all(isposdef.(diagview(Rp)))

# full
Q, R = @constinferred qr_full(A)
@test Q isa Diagonal{T} && size(Q) == (m, m)
@test R isa Diagonal{T} && size(R) == (m, m)
@test Q * R ≈ A
@test isunitary(Q)

# full and positive
Qp, Rp = @constinferred qr_full(A; positive = true)
@test Qp isa Diagonal{T} && size(Qp) == (m, m)
@test Rp isa Diagonal{T} && size(Rp) == (m, m)
@test Qp * Rp ≈ A
@test isunitary(Qp)
@test all(isposdef.(diagview(Rp)))

# null
N = @constinferred qr_null(A)
@test N isa AbstractMatrix{T} && size(N) == (m, 0)
end
end
Loading