Skip to content

Commit 75f8c5e

Browse files
committed
Implement DiagonalAlgorithm for GPU
1 parent 13a1771 commit 75f8c5e

File tree

15 files changed

+420
-161
lines changed

15 files changed

+420
-161
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ end
2727
# Default Hermitian eigenvalue algorithm for `StridedROCMatrix` inputs: build a
# `ROCSOLVER_DivideAndConquer`, forwarding every keyword argument unchanged.
MatrixAlgebraKit.default_eigh_algorithm(
    ::Type{T}; kwargs...
) where {T <: StridedROCMatrix} = ROCSOLVER_DivideAndConquer(; kwargs...)
30+
# For every listed default-algorithm selector, add a method for `Diagonal`
# matrices whose diagonal is stored in a `StridedROCVector`. Each method returns
# a `ROCM_DiagonalAlgorithm` built from the forwarded keyword arguments.
for name in (
        :default_lq_algorithm,
        :default_qr_algorithm,
        :default_eig_algorithm,
        :default_eigh_algorithm,
        :default_svd_algorithm,
    )
    @eval function MatrixAlgebraKit.$name(
            ::Type{T}; kwargs...
        ) where {S, T <: Diagonal{S, <:StridedROCVector}}
        return ROCM_DiagonalAlgorithm(; kwargs...)
    end
end
3041

3142
# Forward the in-place `geqrf!` call for a `StridedROCMatrix` to `YArocSOLVER`.
function _gpu_geqrf!(A::StridedROCMatrix)
    return YArocSOLVER.geqrf!(A)
end
3243
# Forward the in-place `ungqr!` call (matrix `A` plus vector `τ`) to `YArocSOLVER`.
function _gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector)
    return YArocSOLVER.ungqr!(A, τ)
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT
3232
return CUSOLVER_DivideAndConquer(; kwargs...)
3333
end
3434

35+
# For every listed default-algorithm selector, add a method for `Diagonal`
# matrices whose diagonal is stored in a `StridedCuVector`. Each method returns
# a `CUDA_DiagonalAlgorithm` built from the forwarded keyword arguments.
for name in (
        :default_lq_algorithm,
        :default_qr_algorithm,
        :default_eig_algorithm,
        :default_eigh_algorithm,
        :default_svd_algorithm,
    )
    @eval function MatrixAlgebraKit.$name(
            ::Type{T}; kwargs...
        ) where {S, T <: Diagonal{S, <:StridedCuVector}}
        return CUDA_DiagonalAlgorithm(; kwargs...)
    end
end
46+
3547
# include for block sector support
3648
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
3749
return CUSOLVER_HouseholderQR(; kwargs...)

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
3636
export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration
3737
export LQViaTransposedQR
3838
export PolarViaSVD, PolarNewton
39-
export DiagonalAlgorithm
39+
export DiagonalAlgorithm, GPU_DiagonalAlgorithm, CUDA_DiagonalAlgorithm, ROCM_DiagonalAlgorithm
4040
export NativeBlocked
4141
export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar,
4242
CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer

src/implementations/eig.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgori
2828
return nothing
2929
end
3030

31-
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
31+
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithms)
3232
m, n = size(A)
3333
@assert m == n && isdiag(A)
3434
D, V = DV
@@ -40,7 +40,7 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgor
4040
@check_scalar(V, A)
4141
return nothing
4242
end
43-
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
43+
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithms)
4444
m, n = size(A)
4545
@assert m == n && isdiag(A)
4646
@assert D isa AbstractVector
@@ -69,10 +69,10 @@ function initialize_output(::typeof(eig_trunc!), A, alg::TruncatedAlgorithm)
6969
return initialize_output(eig_full!, A, alg.alg)
7070
end
7171

72-
function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithm)
72+
# Output pair for the diagonal `eig_full!`: the first output is `A` itself
# (no new allocation) and the second is a fresh `similar(A)`.
function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithms)
    V = similar(A)
    return (A, V)
end
75-
function initialize_output(::typeof(eig_vals!), A::Diagonal, ::DiagonalAlgorithm)
75+
# Output for the diagonal `eig_vals!`: the result of `diagview(A)`, with no new
# allocation made here.
initialize_output(::typeof(eig_vals!), A::Diagonal, ::DiagonalAlgorithms) = diagview(A)
7878

@@ -123,14 +123,14 @@ end
123123

124124
# Diagonal logic
125125
# --------------
126-
function eig_full!(A::Diagonal, (D, V)::Tuple{Diagonal, Diagonal}, alg::DiagonalAlgorithm)
126+
# Diagonal `eig_full!`: validate the inputs, copy `A` into `D` unless they are
# the same object, overwrite `V` via `one!`, and return `(D, V)`.
function eig_full!(A::Diagonal, DV::Tuple{Diagonal, Diagonal}, alg::DiagonalAlgorithms)
    check_input(eig_full!, A, DV, alg)
    D, V = DV
    if D !== A
        copy!(D, A)
    end
    one!(V)
    return D, V
end
132132

133-
function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithm)
133+
function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithms)
134134
check_input(eig_vals!, A, D, alg)
135135
Ad = diagview(A)
136136
D === Ad || copy!(D, Ad)

src/implementations/eigh.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::AbstractAl
3838
return nothing
3939
end
4040

41-
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalAlgorithm)
41+
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalAlgorithms)
4242
check_hermitian(A, alg)
4343
@assert isdiag(A)
4444
m = size(A, 1)
@@ -51,7 +51,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA
5151
return nothing
5252
end
5353

54-
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::DiagonalAlgorithm)
54+
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::DiagonalAlgorithms)
5555
check_hermitian(A, alg)
5656
@assert isdiag(A)
5757
m = size(A, 1)
@@ -78,10 +78,10 @@ function initialize_output(::typeof(eigh_trunc!), A, alg::TruncatedAlgorithm)
7878
return initialize_output(eigh_full!, A, alg.alg)
7979
end
8080

81-
function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm)
81+
# Output pair for the diagonal `eigh_full!`. For a real element type the first
# output is `A` itself; otherwise it is a `similar` container with the real
# counterpart of the element type. The second output is always `similar(A)`.
function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithms)
    T = eltype(A)
    D = if T <: Real
        A
    else
        similar(A, real(T))
    end
    V = similar(A)
    return D, V
end
84-
function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorithm)
84+
# Output for the diagonal `eigh_vals!`. For a real element type this is
# `diagview(A)`; otherwise a new real-valued container of length `size(A, 1)`.
function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorithms)
    T = eltype(A)
    if T <: Real
        return diagview(A)
    else
        return similar(A, real(T), size(A, 1))
    end
end
8787

@@ -137,15 +137,15 @@ end
137137

138138
# Diagonal logic
139139
# --------------
140-
function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
140+
# Diagonal `eigh_full!`: validate the inputs, write the real parts of the
# diagonal of `A` into the diagonal of `D` unless `D` is `A` itself, overwrite
# `V` via `one!`, and return `(D, V)`.
function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithms)
    check_input(eigh_full!, A, DV, alg)
    D, V = DV
    if D !== A
        diagview(D) .= real.(diagview(A))
    end
    one!(V)
    return D, V
end
147147

148-
function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm)
148+
function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithms)
149149
check_input(eigh_vals!, A, D, alg)
150150
Ad = diagview(A)
151151
D === Ad || (D .= real.(Ad))

src/implementations/lq.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgo
3636
return nothing
3737
end
3838

39-
function check_input(::typeof(lq_full!), A::AbstractMatrix, (L, Q), ::DiagonalAlgorithm)
39+
function check_input(::typeof(lq_full!), A::AbstractMatrix, (L, Q), ::DiagonalAlgorithms)
4040
m, n = size(A)
4141
@assert m == n && isdiag(A)
4242
@assert Q isa Diagonal && L isa Diagonal
@@ -46,10 +46,10 @@ function check_input(::typeof(lq_full!), A::AbstractMatrix, (L, Q), ::DiagonalAl
4646
@check_scalar(Q, A)
4747
return nothing
4848
end
49-
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
49+
# The compact diagonal LQ reuses the `lq_full!` input checks.
check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ, alg::DiagonalAlgorithms) =
    check_input(lq_full!, A, LQ, alg)
52-
function check_input(::typeof(lq_null!), A::AbstractMatrix, N, ::DiagonalAlgorithm)
52+
function check_input(::typeof(lq_null!), A::AbstractMatrix, N, ::DiagonalAlgorithms)
5353
m, n = size(A)
5454
@assert m == n && isdiag(A)
5555
@assert N isa AbstractMatrix
@@ -81,7 +81,7 @@ function initialize_output(::typeof(lq_null!), A::AbstractMatrix, ::AbstractAlgo
8181
end
8282

8383
# `lq_full!` and `lq_compact!` share one output layout for the diagonal
# algorithms: a fresh `similar(A)` first, then `A` itself.
for lqfun in (:lq_full!, :lq_compact!)
    @eval function initialize_output(
            ::typeof($lqfun), A::AbstractMatrix, ::DiagonalAlgorithms
        )
        L = similar(A)
        return L, A
    end
end
@@ -124,19 +124,19 @@ function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR)
124124
return Nᴴ
125125
end
126126

127-
function lq_full!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
127+
# Diagonal `lq_full!`: validate the inputs, then let `_diagonal_lq!` fill the
# two outputs using the keyword arguments stored in `alg`.
function lq_full!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithms)
    check_input(lq_full!, A, LQ, alg)
    Lout, Qout = LQ
    opts = alg.kwargs
    _diagonal_lq!(A, Lout, Qout; opts...)
    return Lout, Qout
end
133-
function lq_compact!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
133+
# Diagonal `lq_compact!`: validate with the `lq_compact!` checks, then let
# `_diagonal_lq!` fill the two outputs using the keyword arguments in `alg`.
function lq_compact!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithms)
    check_input(lq_compact!, A, LQ, alg)
    Lout, Qout = LQ
    opts = alg.kwargs
    _diagonal_lq!(A, Lout, Qout; opts...)
    return Lout, Qout
end
139-
function lq_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithm)
139+
# Diagonal `lq_null!`: validate the inputs and return whatever
# `_diagonal_lq_null!` produces for `A` and `N`.
function lq_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithms)
    check_input(lq_null!, A, N, alg)
    opts = alg.kwargs
    result = _diagonal_lq_null!(A, N; opts...)
    return result
end

src/implementations/qr.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function check_input(::typeof(qr_null!), A::AbstractMatrix, N, ::AbstractAlgorit
3636
return nothing
3737
end
3838

39-
function check_input(::typeof(qr_full!), A::AbstractMatrix, (Q, R), alg::DiagonalAlgorithm)
39+
function check_input(::typeof(qr_full!), A::AbstractMatrix, (Q, R), alg::DiagonalAlgorithms)
4040
m, n = size(A)
4141
@assert m == n && isdiag(A)
4242
@assert Q isa Diagonal && R isa Diagonal
@@ -46,10 +46,10 @@ function check_input(::typeof(qr_full!), A::AbstractMatrix, (Q, R), alg::Diagona
4646
@check_scalar(R, A)
4747
return nothing
4848
end
49-
function check_input(::typeof(qr_compact!), A::AbstractMatrix, QR, alg::DiagonalAlgorithm)
49+
# The compact diagonal QR reuses the `qr_full!` input checks.
check_input(::typeof(qr_compact!), A::AbstractMatrix, QR, alg::DiagonalAlgorithms) =
    check_input(qr_full!, A, QR, alg)
52-
function check_input(::typeof(qr_null!), A::AbstractMatrix, N, ::DiagonalAlgorithm)
52+
function check_input(::typeof(qr_null!), A::AbstractMatrix, N, ::DiagonalAlgorithms)
5353
m, n = size(A)
5454
@assert m == n && isdiag(A)
5555
@assert N isa AbstractMatrix
@@ -81,7 +81,7 @@ function initialize_output(::typeof(qr_null!), A::AbstractMatrix, ::AbstractAlgo
8181
end
8282

8383
# `qr_full!` and `qr_compact!` share one output layout for the diagonal
# algorithms: `A` itself first, then a fresh `similar(A)`.
for qrfun in (:qr_full!, :qr_compact!)
    @eval function initialize_output(
            ::typeof($qrfun), A::AbstractMatrix, ::DiagonalAlgorithms
        )
        R = similar(A)
        return A, R
    end
end
@@ -107,19 +107,19 @@ function qr_null!(A::AbstractMatrix, N, alg::LAPACK_HouseholderQR)
107107
return N
108108
end
109109

110-
function qr_full!(A::AbstractMatrix, QR, alg::DiagonalAlgorithm)
110+
# Diagonal `qr_full!`: validate the inputs, then let `_diagonal_qr!` fill the
# two outputs using the keyword arguments stored in `alg`.
function qr_full!(A::AbstractMatrix, QR, alg::DiagonalAlgorithms)
    check_input(qr_full!, A, QR, alg)
    Qout, Rout = QR
    opts = alg.kwargs
    _diagonal_qr!(A, Qout, Rout; opts...)
    return Qout, Rout
end
116-
function qr_compact!(A::AbstractMatrix, QR, alg::DiagonalAlgorithm)
116+
# Diagonal `qr_compact!`: validate with the `qr_compact!` checks, then let
# `_diagonal_qr!` fill the two outputs using the keyword arguments in `alg`.
function qr_compact!(A::AbstractMatrix, QR, alg::DiagonalAlgorithms)
    check_input(qr_compact!, A, QR, alg)
    Qout, Rout = QR
    opts = alg.kwargs
    _diagonal_qr!(A, Qout, Rout; opts...)
    return Qout, Rout
end
122-
function qr_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithm)
122+
function qr_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithms)
123123
check_input(qr_null!, A, N, alg)
124124
_diagonal_qr_null!(A, N; alg.kwargs...)
125125
return N

src/implementations/svd.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function check_input(::typeof(svd_vals!), A::AbstractMatrix, S, ::AbstractAlgori
4242
return nothing
4343
end
4444

45-
function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ, ::DiagonalAlgorithm)
45+
function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ, ::DiagonalAlgorithms)
4646
m, n = size(A)
4747
@assert m == n && isdiag(A)
4848
U, S, Vᴴ = USVᴴ
@@ -56,11 +56,11 @@ function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ, ::DiagonalA
5656
return nothing
5757
end
5858
# The compact diagonal SVD reuses the `svd_full!` input checks.
check_input(::typeof(svd_compact!), A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithms) =
    check_input(svd_full!, A, USVᴴ, alg)
63-
function check_input(::typeof(svd_vals!), A::AbstractMatrix, S, ::DiagonalAlgorithm)
63+
function check_input(::typeof(svd_vals!), A::AbstractMatrix, S, ::DiagonalAlgorithms)
6464
m, n = size(A)
6565
@assert m == n && isdiag(A)
6666
@assert S isa AbstractVector
@@ -93,15 +93,15 @@ function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm)
9393
return initialize_output(svd_compact!, A, alg.alg)
9494
end
9595

96-
function initialize_output(::typeof(svd_full!), A::Diagonal, ::DiagonalAlgorithm)
96+
# Allocate the three outputs of the diagonal `svd_full!`. The first and third
# use the element type inferred for `sign_safe` applied to `eltype(A)`; the
# middle one uses the real counterpart of `eltype(A)`.
function initialize_output(::typeof(svd_full!), A::Diagonal, ::DiagonalAlgorithms)
    elT = eltype(A)
    vecT = Base.promote_op(sign_safe, elT)
    dims = size(A)
    U = similar(A, vecT, dims)
    S = similar(A, real(elT))
    Vᴴ = similar(A, vecT, dims)
    return U, S, Vᴴ
end
101-
function initialize_output(::typeof(svd_compact!), A::Diagonal, alg::DiagonalAlgorithm)
101+
# The compact diagonal SVD uses the same output buffers as `svd_full!`.
initialize_output(::typeof(svd_compact!), A::Diagonal, alg::DiagonalAlgorithms) =
    initialize_output(svd_full!, A, alg)
104-
function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm)
104+
# Output for the diagonal `svd_vals!`. For a real element type this is
# `diagview(A)`; otherwise a new real-valued container of length `size(A, 1)`.
function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithms)
    T = eltype(A)
    if T <: Real
        return diagview(A)
    else
        return similar(A, real(T), size(A, 1))
    end
end
107107

@@ -214,11 +214,16 @@ end
214214

215215
# Diagonal logic
216216
# --------------
217-
function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
217+
function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithms)
218218
check_input(svd_full!, A, USVᴴ, alg)
219219
Ad = diagview(A)
220220
U, S, Vᴴ = USVᴴ
221-
p = sortperm(Ad; by = abs, rev = true)
221+
p = if isempty(Ad)
222+
Int[]
223+
else
224+
sortperm(Ad; by = abs, rev = true)
225+
end
226+
222227
zero!(U)
223228
zero!(Vᴴ)
224229
n = size(A, 1)
@@ -239,10 +244,10 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
239244

240245
return U, S, Vᴴ
241246
end
242-
function svd_compact!(A, USVᴴ, alg::DiagonalAlgorithm)
247+
# The compact diagonal SVD delegates directly to `svd_full!`.
svd_compact!(A, USVᴴ, alg::DiagonalAlgorithms) = svd_full!(A, USVᴴ, alg)
245-
function svd_vals!(A::AbstractMatrix, S, alg::DiagonalAlgorithm)
250+
function svd_vals!(A::AbstractMatrix, S, alg::DiagonalAlgorithms)
246251
check_input(svd_vals!, A, S, alg)
247252
Ad = diagview(A)
248253
S .= abs.(Ad)

src/interface/decompositions.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,14 @@ the diagonal structure of the input and outputs.
193193
"""
194194
@algdef DiagonalAlgorithm
195195

196+
"""
197+
GPUDiagonalAlgorithm(; kwargs...)
198+
199+
Algorithm type to denote a native *on-GPU* Julia implementation of the decompositions making use of
200+
the diagonal structure of the input and outputs.
201+
"""
202+
@algdef GPUDiagonalAlgorithm # NOTE(review): separate from the `GPU_DiagonalAlgorithm` union.
# It is absent from the visible export list and from `DiagonalAlgorithms`, so the
# diagonal methods shown here would not dispatch on it — confirm it is still needed.
203+
196204
"""
197205
LQViaTransposedQR(qr_alg)
198206
@@ -296,6 +304,14 @@ const CUSOLVER_SVDAlgorithm = Union{
296304
CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized,
297305
}
298306

307+
"""
308+
CUDA_DiagonalAlgorithm(; kwargs...)
309+
310+
Algorithm type to denote a native Julia implementation of the decompositions making use of
311+
the diagonal structure of the input and outputs, specialized for CUDA.
312+
"""
313+
@algdef CUDA_DiagonalAlgorithm
314+
299315
# =========================
300316
# ROCSOLVER ALGORITHMS
301317
# =========================
@@ -353,6 +369,14 @@ singular vectors, see also [`gaugefix!`](@ref).
353369

354370
const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi}
355371

372+
"""
373+
ROCM_DiagonalAlgorithm(; kwargs...)
374+
375+
Algorithm type to denote a native Julia implementation of the decompositions making use of
376+
the diagonal structure of the input and outputs, specialized for ROCm.
377+
"""
378+
@algdef ROCM_DiagonalAlgorithm
379+
356380
# Various consts and unions
357381
# -------------------------
358382

@@ -370,11 +394,15 @@ const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm}
370394
const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
371395
const GPU_Randomized = Union{CUSOLVER_Randomized}
372396

397+
# Union of the backend-specific diagonal algorithm types (CUDA and ROCm).
const GPU_DiagonalAlgorithm = Union{CUDA_DiagonalAlgorithm, ROCM_DiagonalAlgorithm}
398+
373399
const QRAlgorithms = Union{LAPACK_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}
374400
const LQAlgorithms = Union{LAPACK_HouseholderLQ, LQViaTransposedQR}
375401
const SVDAlgorithms = Union{LAPACK_SVDAlgorithm, GPU_SVDAlgorithm}
376402
const PolarAlgorithms = Union{PolarViaSVD, PolarNewton}
377403

404+
# Dispatch alias: the generic `DiagonalAlgorithm` together with every member of
# `GPU_DiagonalAlgorithm`.
const DiagonalAlgorithms = Union{DiagonalAlgorithm, GPU_DiagonalAlgorithm}
405+
378406
# ================================
379407
# ORTHOGONALIZATION ALGORITHMS
380408
# ================================

0 commit comments

Comments
 (0)