Skip to content
Merged
20 changes: 8 additions & 12 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,14 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::Strid
return nothing
end

MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A))
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
all(A.diag .== adjoint(A.diag))
MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; kwargs...) =
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)

MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) =
all(A .== -adjoint(A))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
all(A.diag .== -adjoint(A.diag))
MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; kwargs...) =
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
# avoids calling the `StridedMatrix` specialization to avoid scalar indexing,
# use (allocating) fallback instead until we write a dedicated kernel
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = A == A'
MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) =
norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = A == -A'
MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) =
norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))

function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
axes(A) == axes(B) || throw(DimensionMismatch())
Expand Down
21 changes: 8 additions & 13 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,14 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::Stride
return nothing
end

MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) =
all(A .== adjoint(A))
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
all(A.diag .== adjoint(A.diag))
MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; kwargs...) =
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)

MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) =
all(A .== -adjoint(A))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
all(A.diag .== -adjoint(A.diag))
MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; kwargs...) =
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
# avoids calling the `StridedMatrix` specialization to avoid scalar indexing,
# use (allocating) fallback instead until we write a dedicated kernel
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = A == A'
MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) =
norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = A == -A'
MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) =
norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))

function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
axes(A) == axes(B) || throw(DimensionMismatch())
Expand Down
2 changes: 1 addition & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MatrixAlgebraKit
using LinearAlgebra: LinearAlgebra
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
using LinearAlgebra: sylvester, lu!
using LinearAlgebra: sylvester, lu!, diagm
using LinearAlgebra: isposdef, issymmetric
using LinearAlgebra: Diagonal, diag, diagind, isdiag
using LinearAlgebra: UpperTriangular, LowerTriangular
Expand Down
28 changes: 21 additions & 7 deletions src/common/matrixproperties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,13 @@ end

ishermitian_exact(A) = A == A'
ishermitian_exact(A::StridedMatrix; kwargs...) = strided_ishermitian_exact(A, Val(false); kwargs...)
ishermitian_exact(A::Diagonal) = diagonal_ishermitian_exact(A, Val(false))

function ishermitian_approx(A; atol, rtol, kwargs...)
return norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
end
ishermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(false); kwargs...)
ishermitian_approx(A::Diagonal; kwargs...) = diagonal_ishermitian_approx(A, Val(false); kwargs...)

"""
isantihermitian(A; isapprox_kwargs...)
Expand All @@ -97,16 +100,15 @@ function isantihermitian(A; atol::Real = 0, rtol::Real = 0, kwargs...)
return isantihermitian_approx(A; atol, rtol, kwargs...)
end
end
function isantihermitian_exact(A)
return A == -A'
end
function isantihermitian_exact(A::StridedMatrix; kwargs...)
return strided_ishermitian_exact(A, Val(true); kwargs...)
end
isantihermitian_exact(A) = A == -A'
isantihermitian_exact(A::StridedMatrix; kwargs...) = strided_ishermitian_exact(A, Val(true); kwargs...)
isantihermitian_exact(A::Diagonal) = diagonal_ishermitian_exact(A, Val(true))

function isantihermitian_approx(A; atol, rtol, kwargs...)
return norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
end
isantihermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(true); kwargs...)
isantihermitian_approx(A::Diagonal; kwargs...) = diagonal_ishermitian_approx(A, Val(true); kwargs...)

# blocked implementation of exact checks for strided matrices
# -----------------------------------------------------------
Expand Down Expand Up @@ -145,7 +147,6 @@ function _ishermitian_exact_offdiag(Al, Au, ::Val{anti}) where {anti}
return true
end


function strided_ishermitian_approx(
A::AbstractMatrix, anti::Val;
blocksize = 32, atol::Real = default_hermitian_tol(A), rtol::Real = 0
Expand Down Expand Up @@ -192,3 +193,16 @@ function _ishermitian_approx_offdiag(Al, Au, ::Val{anti}) where {anti}
end
return ϵ²
end

diagonal_ishermitian_exact(A, ::Val{anti}) where {anti} = all(iszero ∘ (anti ? real : imag), diagview(A))

function diagonal_ishermitian_approx(
A, ::Val{anti}; atol::Real = default_hermitian_tol(A), rtol::Real = 0
) where {anti}
m, n = size(A)
m == n || throw(DimensionMismatch())
init = abs2(zero(eltype(A)))
ϵ² = sum(abs2 ∘ (anti ? real : imag), diagview(A); init)
ϵ²max = oftype(ϵ², rtol > 0 ? max(atol, rtol * norm(A)) : atol)^2
return ϵ² ≤ ϵ²max
end
15 changes: 13 additions & 2 deletions src/implementations/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ copy_input(::typeof(project_isometric), A) = copy_input(left_polar, A)

function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
LinearAlgebra.checksquare(A)
n = Base.require_one_based_indexing(A)
Base.require_one_based_indexing(A)
n = size(A, 1)
B === A || @check_size(B, (n, n))
return nothing
end
function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
LinearAlgebra.checksquare(A)
n = Base.require_one_based_indexing(A)
Base.require_one_based_indexing(A)
n = size(A, 1)
B === A || @check_size(B, (n, n))
return nothing
end
Expand Down Expand Up @@ -61,6 +63,15 @@ function project_isometric!(A::AbstractMatrix, W, alg::AbstractAlgorithm)
return W
end

function project_hermitian_native!(A::Diagonal, B::Diagonal, ::Val{anti}; kwargs...) where {anti}
if anti
diagview(A) .= _imimag.(diagview(B))
else
diagview(A) .= real.(diagview(B))
end
return A
end

function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)
n = size(A, 1)
for j in 1:blocksize:n
Expand Down
84 changes: 55 additions & 29 deletions test/amd/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,40 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
m = 54
noisefactor = eps(real(T))^(3 / 4)
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
A = ROCArray(randn(rng, T, m, m))
Ah = (A + A') / 2
Aa = (A - A') / 2
Ac = copy(A)
for A in (ROCArray(randn(rng, T, m, m)), Diagonal(ROCArray(randn(rng, T, m))))
Ah = (A + A') / 2
Aa = (A - A') / 2
Ac = copy(A)

Bh = project_hermitian(A, alg)
@test ishermitian(Bh)
@test Bh ≈ Ah
@test A == Ac
Bh_approx = Bh + noisefactor * Aa
@test !ishermitian(Bh_approx)
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
Bh = project_hermitian(A, alg)
@test ishermitian(Bh)
@test Bh ≈ Ah
@test A == Ac
Bh_approx = Bh + noisefactor * Aa
# this is still hermitian for real Diagonal: |A - A'| == 0
@test !ishermitian(Bh_approx) || norm(Aa) == 0
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)

Ba = project_antihermitian(A, alg)
@test isantihermitian(Ba)
@test Ba ≈ Aa
@test A == Ac
Ba_approx = Ba + noisefactor * Ah
@test !isantihermitian(Ba_approx)
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
Ba = project_antihermitian(A, alg)
@test isantihermitian(Ba)
@test Ba ≈ Aa
@test A == Ac
Ba_approx = Ba + noisefactor * Ah
@test !isantihermitian(Ba_approx)
# this is never anti-hermitian for real Diagonal: |A - A'| == 0
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0

Bh = project_hermitian!(Ac, alg)
@test Bh === Ac
@test ishermitian(Bh)
@test Bh ≈ Ah
Bh = project_hermitian!(Ac, alg)
@test Bh === Ac
@test ishermitian(Bh)
@test Bh ≈ Ah

copy!(Ac, A)
Ba = project_antihermitian!(Ac, alg)
@test Ba === Ac
@test isantihermitian(Ba)
@test Ba ≈ Aa
copy!(Ac, A)
Ba = project_antihermitian!(Ac, alg)
@test Ba === Ac
@test isantihermitian(Ba)
@test Ba ≈ Aa
end
end
end

Expand All @@ -68,10 +71,33 @@ end

# test that W is closer to A then any other isometry
for k in 1:10
δA = ROCArray(randn(rng, T, m, n))
δA = ROCArray(randn(rng, T, size(A)...))
W = project_isometric(A, alg)
W2 = project_isometric(A + δA / 100, alg)
@test norm(A - W2) >= norm(A - W)
end
end

m == n && @testset "DiagonalAlgorithm" begin
A = Diagonal(ROCArray(randn(rng, T, m)))
alg = PolarViaSVD(DiagonalAlgorithm())
W = project_isometric(A, alg)
@test isisometric(W)
W2 = project_isometric(W, alg)
@test W2 ≈ W # stability of the projection
@test W * (W' * A) ≈ A

Ac = similar(A)
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
@test W2 === W
@test isisometric(W)

# test that W is closer to A then any other isometry
for k in 1:10
δA = Diagonal(ROCArray(randn(rng, T, m)))
W = project_isometric(A, alg)
W2 = project_isometric(A + δA / 100, alg)
@test norm(A - W2) > norm(A - W)
@test norm(A - W2) >= norm(A - W)
end
end
end
Expand Down
84 changes: 55 additions & 29 deletions test/cuda/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,40 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
m = 54
noisefactor = eps(real(T))^(3 / 4)
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
A = CuArray(randn(rng, T, m, m))
Ah = (A + A') / 2
Aa = (A - A') / 2
Ac = copy(A)
for A in (CuArray(randn(rng, T, m, m)), Diagonal(CuArray(randn(rng, T, m))))
Ah = (A + A') / 2
Aa = (A - A') / 2
Ac = copy(A)

Bh = project_hermitian(A, alg)
@test ishermitian(Bh)
@test Bh ≈ Ah
@test A == Ac
Bh_approx = Bh + noisefactor * Aa
@test !ishermitian(Bh_approx)
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
Bh = project_hermitian(A, alg)
@test ishermitian(Bh)
@test Bh ≈ Ah
@test A == Ac
Bh_approx = Bh + noisefactor * Aa
# this is still hermitian for real Diagonal: |A - A'| == 0
@test !ishermitian(Bh_approx) || norm(Aa) == 0
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)

Ba = project_antihermitian(A, alg)
@test isantihermitian(Ba)
@test Ba ≈ Aa
@test A == Ac
Ba_approx = Ba + noisefactor * Ah
@test !isantihermitian(Ba_approx)
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
Ba = project_antihermitian(A, alg)
@test isantihermitian(Ba)
@test Ba ≈ Aa
@test A == Ac
Ba_approx = Ba + noisefactor * Ah
@test !isantihermitian(Ba_approx)
# this is never anti-hermitian for real Diagonal: |A - A'| == 0
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0

Bh = project_hermitian!(Ac, alg)
@test Bh === Ac
@test ishermitian(Bh)
@test Bh ≈ Ah
Bh = project_hermitian!(Ac, alg)
@test Bh === Ac
@test ishermitian(Bh)
@test Bh ≈ Ah

copy!(Ac, A)
Ba = project_antihermitian!(Ac, alg)
@test Ba === Ac
@test isantihermitian(Ba)
@test Ba ≈ Aa
copy!(Ac, A)
Ba = project_antihermitian!(Ac, alg)
@test Ba === Ac
@test isantihermitian(Ba)
@test Ba ≈ Aa
end
end
end

Expand All @@ -68,10 +71,33 @@ end

# test that W is closer to A then any other isometry
for k in 1:10
δA = CuArray(randn(rng, T, m, n))
δA = CuArray(randn(rng, T, size(A)...))
W = project_isometric(A, alg)
W2 = project_isometric(A + δA / 100, alg)
@test norm(A - W2) >= norm(A - W)
end
end

m == n && @testset "DiagonalAlgorithm" begin
A = Diagonal(CuArray(randn(rng, T, m)))
alg = PolarViaSVD(DiagonalAlgorithm())
W = project_isometric(A, alg)
@test isisometric(W)
W2 = project_isometric(W, alg)
@test W2 ≈ W # stability of the projection
@test W * (W' * A) ≈ A

Ac = similar(A)
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
@test W2 === W
@test isisometric(W)

# test that W is closer to A then any other isometry
for k in 1:10
δA = Diagonal(CuArray(randn(rng, T, m)))
W = project_isometric(A, alg)
W2 = project_isometric(A + δA / 100, alg)
@test norm(A - W2) > norm(A - W)
@test norm(A - W2) >= norm(A - W)
end
end
end
Expand Down
Loading