diff --git a/Project.toml b/Project.toml index 934c0ceb..a5afbc17 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,8 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a" GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore" @@ -19,6 +21,8 @@ MatrixAlgebraKitAMDGPUExt = "AMDGPU" MatrixAlgebraKitCUDAExt = "CUDA" MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra" MatrixAlgebraKitGenericSchurExt = "GenericSchur" +MatrixAlgebraKitEnzymeExt = "Enzyme" +MatrixAlgebraKitMooncakeExt = "Mooncake" [compat] AMDGPU = "2" @@ -28,8 +32,11 @@ ChainRulesTestUtils = "1" CUDA = "5" GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" +Enzyme = "0.13.96" +EnzymeTestUtils = "0.2.3" JET = "0.9, 0.10" LinearAlgebra = "1" +Mooncake = "0.4.167" SafeTestsets = "0.1" StableRNGs = "1" Test = "1" @@ -40,7 +47,9 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -48,4 +57,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Enzyme", "EnzymeTestUtils", "Mooncake"] diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl new file mode 100644 index 00000000..a80dce67 --- /dev/null +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -0,0 +1,669 @@ +module MatrixAlgebraKitEnzymeExt + +using MatrixAlgebraKit +using MatrixAlgebraKit: diagview, inv_safe, eig_trunc!, eigh_trunc! +using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pushforward!, lq_pushforward! +using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pushforward!, lq_null_pushforward! +using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pushforward!, eigh_pushforward! +using MatrixAlgebraKit: svd_pushforward! +using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pushforward!, right_polar_pushforward! +using Enzyme +using Enzyme.EnzymeCore +using Enzyme.EnzymeCore: EnzymeRules +using LinearAlgebra + +@inline EnzymeRules.inactive_type(v::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = true + + +# two-argument factorizations like LQ, QR, EIG +for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pushforward!), + (qr_compact!, qr_pullback!, qr_pushforward!), + (lq_full!, lq_pullback!, lq_pushforward!), + (lq_compact!, lq_pullback!, lq_pushforward!), + (eig_full!, eig_pullback!, eig_pushforward!), + (left_polar!, left_polar_pullback!, left_polar_pushforward!), + (right_polar!, right_polar_pullback!, right_polar_pushforward!), + ) + @eval begin + function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + cache_arg = nothing + # form cache if needed + cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(arg) <: Const)) ? copy(A.val) : nothing + func.val(A.val, arg.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? arg.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg)) + end + function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs...) where {RT} + cache_A, cache_arg = cache + argval = arg.val + Aval = !isnothing(cache_A) ? cache_A : A.val + ∂arg = isa(arg, Const) ? nothing : arg.dval + if !isa(A, Const) && !isa(arg, Const) + A.dval .= zero(eltype(Aval)) + $pb(A.dval, A.val, argval, ∂arg; kwargs...) + end + !isa(arg, Const) && make_zero!(arg.dval) + return (nothing, nothing, nothing) + end + function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + ret = func.val(A.val, arg.val, alg.val; kwargs...) + arg1, arg2 = ret + m, n = size(A.val) + + if isa(arg, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const) + darg1, darg2 = arg.dval + dA = A.dval + darg1, darg2 = $pf(dA, A.val, ret, arg.dval) + dA .= zero(eltype(A.val)) + shadow = (darg1, darg2) + elseif isa(A, Const) && !!isa(arg, Union{Duplicated, DuplicatedNoNeed}) + make_zero!(arg.dval) + shadow = arg.dval + end + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(ret, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return ret + else + return nothing + end + end + end +end + +for (f, pb, pf) in ((qr_null!, qr_null_pullback!, qr_null_pushforward!), + (lq_null!, lq_null_pullback!, lq_null_pushforward!), + ) + @eval begin + function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:AbstractMatrix}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + cache_arg = nothing + # form cache if needed + cache_A = nothing #copy(A.val) + func.val(copy(A.val), arg.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? arg.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg)) + end + + function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:AbstractMatrix}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(arg.val), + rank_atol::Real=tol, + gauge_atol::Real=tol, + kwargs...) where {RT} + cache_A, cache_arg = cache + Aval = isnothing(cache_A) ? A.val : cache_A + if !isa(A, Const) && !isa(arg, Const) + A.dval .= zero(eltype(A.val)) + $pb(A.dval, A.val, arg.val, arg.dval; kwargs...) + end + return (nothing, nothing, nothing) + end + function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:AbstractMatrix}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + ret = func.val(A.val, arg.val, alg.val; kwargs...) + + if isa(arg, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const) + darg = arg.dval + dA = A.dval + $pf(dA, A.val, arg.val, darg) + shadow = darg + elseif isa(A, Const) && !!isa(arg, Union{Duplicated, DuplicatedNoNeed}) + make_zero!(arg.dval) + shadow = arg.dval + end + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(ret, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return ret + else + return nothing + end + end + end +end + + +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof(svd_compact!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing + shadow = if EnzymeRules.needs_shadow(config) + svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval) + else + nothing + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(ret, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return ret + else + return nothing + end +end + +# TODO +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof(svd_full!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing + shadow = if EnzymeRules.needs_shadow(config) + svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval) + else + nothing + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(ret, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return ret + else + return nothing + end +end +for f in (:svd_compact!, :svd_full!) + @eval begin + function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + # form cache if needed + cache_USVᴴ = (EnzymeRules.overwritten(config)[3] && !(typeof(USVᴴ) <: Const)) ? copy(USVᴴ.val) : nothing + cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(A) <: Const)) ? copy(A.val) : nothing + func.val(A.val, USVᴴ.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? USVᴴ.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? USVᴴ.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ)) + end + function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs...) where {RT} + cache_A, cache_USVᴴ = cache + USVᴴval = !isnothing(cache_USVᴴ) ? cache_USVᴴ : USVᴴ.val + ∂USVᴴ = isa(USVᴴ, Const) ? nothing : USVᴴ.dval + if !isa(A, Const) && !isa(USVᴴ, Const) + A.dval .= zero(eltype(A.dval)) + MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴval, ∂USVᴴ; kwargs...) + end + if !isa(USVᴴ, Const) + make_zero!(USVᴴ.dval) + end + return (nothing, nothing, nothing) + end + end +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + ϵ::Annotation{Vector{T}}, + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; + kwargs..., + ) where {RT, T<:Real} + # form cache if needed + cache_A = copy(A.val) + svd_compact!(A.val, USVᴴ.val, alg.val.alg) + cache_USVᴴ = copy.(USVᴴ.val) + USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc) + ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind) + primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing + shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴ, Const) + dU, dS, dVᴴ = USVᴴ.dval + dStrunc = Diagonal(diagview(dS)[ind]) + dUtrunc = dU[:, ind] + dVᴴtrunc = dVᴴ[ind, :] + (dUtrunc, dStrunc, dVᴴtrunc) + else + (nothing, nothing, nothing) + end + shadow = EnzymeRules.needs_shadow(config) ? (shadow_USVᴴ..., ϵ.dval) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, shadow_USVᴴ, ind)) +end +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc!)}, + dret::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + ϵ::Annotation{Vector{T}}, + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; + kwargs...) where {RT, T<:Real} + cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache + U, S, Vᴴ = cache_USVᴴ + dU, dS, dVᴴ = shadow_USVᴴ + if !isa(A, Const) && !isa(USVᴴ, Const) + A.dval .= zero(eltype(A.val)) + A.dval .= MatrixAlgebraKit.svd_pullback!(A.dval, A.val, (U, S, Vᴴ), shadow_USVᴴ, ind; kwargs...) + end + if !isa(USVᴴ, Const) + make_zero!(USVᴴ.dval) + end + if !isa(ϵ, Const) + make_zero!(ϵ.dval) + end + return (nothing, nothing, nothing, nothing) +end + +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof(eigh_vals!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + Dmat, V = eigh_full(A.val; kwargs...) + if isa(D, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const) + ∂K = inv(V) * A.dval * V + ∂Kdiag = diag(∂K) + D.dval .= real.(copy(∂Kdiag)) + A.dval .= zero(eltype(A.val)) + shadow = D.dval + elseif isa(A, Const) && !!isa(D, Union{Duplicated, DuplicatedNoNeed}) + make_zero!(D.dval) + shadow = D.dval + end + eigh_vals!(A.val, zeros(real(eltype(A.val)), size(A.val, 1))) + D.val .= diagview(Dmat) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(Dmat.diag, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return Dmat.diag + else + return nothing + end +end + +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof(eigh_full!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + Dmat, V = func.val(A.val, DV.val; kwargs...) + if isa(A, Const) || all(iszero, A.dval) + make_zero!(DV.dval[1]) + make_zero!(DV.dval[2]) + make_zero!(A.dval) + shadow = (DV.dval[1], DV.dval[2]) + else + ∂K = inv(V) * A.dval * V + ∂Kdiag = diagview(∂K) + ∂Ddiag = diagview(DV.dval[1]) + ∂Ddiag .= real.(∂Kdiag) + D = diagview(Dmat) + dDD = transpose(D) .- D + ∂K ./= dDD + ∂Kdiag .= zero(eltype(V)) + mul!(DV.dval[2], V, ∂K, 1, 0) + shadow = DV.dval[2] + A.dval .= zero(eltype(A.val)) + shadow = (Diagonal(∂Ddiag), DV.dval[2]) + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated((Dmat, V), shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return (Dmat, V) + else + return nothing + end +end + +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof(eig_vals!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + Dval, V = eig_full(A.val, alg.val; kwargs...) + if isa(D, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const) + ∂K = inv(V) * A.dval * V + ∂Kdiag = diag(∂K) + D.dval .= copy(∂Kdiag) + A.dval .= zero(eltype(A.val)) + shadow = D.dval + elseif isa(A, Const) && !!isa(D, Union{Duplicated, DuplicatedNoNeed}) + make_zero!(D.dval) + shadow = D.dval + end + eig_vals!(A.val, zeros(complex(eltype(A.val)), size(A.val, 1))) + D.val .= diagview(Dval) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(Dmat.diag, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return Dmat.diag + else + return nothing + end +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_trunc!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + ϵ::Annotation{Vector{T}}, + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; + kwargs..., + ) where {RT, T} + # form cache if needed + cache_A = copy(A.val) + MatrixAlgebraKit.eigh_full!(A.val, DV.val, alg.val.alg) + cache_DV = copy.(DV.val) + DV′, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc) + ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind) + primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing + shadow_DV = if !isa(A, Const) && !isa(DV, Const) + dD, dV = DV.dval + dDtrunc = Diagonal(diagview(dD)[ind]) + dVtrunc = dV[:, ind] + (dDtrunc, dVtrunc) + else + (nothing, nothing) + end + shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind)) +end +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_trunc!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + ϵ::Annotation{Vector{T}}, + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; + kwargs...) where {RT, T} + cache_A, cache_DV, cache_dDVtrunc, ind = cache + D, V = cache_DV + dD, dV = cache_dDVtrunc + if !isa(A, Const) && !isa(DV, Const) + A.dval .= zero(eltype(A.val)) + A.dval .= MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, (D, V), (dD, dV), ind; kwargs...) + end + if !isa(DV, Const) + make_zero!(DV.dval) + end + if !isa(ϵ, Const) + make_zero!(ϵ.dval) + end + return (nothing, nothing, nothing, nothing) +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eig_trunc!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + ϵ::Annotation{Vector{T}}, + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; + kwargs..., + ) where {RT, T} + # form cache if needed + cache_A = copy(A.val) + eig_full!(A.val, DV.val, alg.val.alg) + cache_DV = copy.(DV.val) + DV′, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc) + ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind) + primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing + shadow_DV = if !isa(A, Const) && !isa(DV, Const) + dD, dV = DV.dval + dDtrunc = Diagonal(diagview(dD)[ind]) + dVtrunc = dV[:, ind] + (dDtrunc, dVtrunc) + else + (nothing, nothing) + end + shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind)) +end +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eig_trunc!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + ϵ::Annotation{Vector{T}}, + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; + kwargs...) where {RT, T} + cache_A, cache_DV, cache_dDVtrunc, ind = cache + D, V = cache_DV + dD, dV = cache_dDVtrunc + if !isa(A, Const) && !isa(DV, Const) + A.dval .= zero(eltype(A.val)) + A.dval .= MatrixAlgebraKit.eig_pullback!(A.dval, A.val, (D, V), (dD, dV), ind; kwargs...) + end + if !isa(DV, Const) + make_zero!(DV.dval) + end + if !isa(ϵ, Const) + make_zero!(ϵ.dval) + end + return (nothing, nothing, nothing, nothing) +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_full!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + # form cache if needed + cache_DV = nothing + cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + func.val(A.val, DV.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? DV.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? DV.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV)) +end + +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_full!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + + cache_A, cache_DV = cache + DVval = !isnothing(cache_DV) ? cache_DV : DV.val + Aval = !isnothing(cache_A) ? cache_A : A.val + ∂DV = isa(DV, Const) ? nothing : DV.dval + if !isa(A, Const) && !isa(DV, Const) + Dmat, V = DVval + ∂Dmat, ∂V = ∂DV + A.dval .= zero(eltype(Aval)) + MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, DVval, ∂DV; kwargs...) + A.dval .*= 2 + diagview(A.dval) ./= 2 + for i in 1:size(A.dval, 1), j in 1:size(A.dval, 2) + if i > j + A.dval[i, j] = zero(eltype(A.dval)) + end + end + end + if !isa(DV, Const) + make_zero!(DV.dval) + end + return (nothing, nothing, nothing) +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eig_vals!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + cache_D = nothing + cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + func.val(A.val, D.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? D.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing + # form cache if needed + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D)) +end +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eig_vals!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + + cache_A, cache_D = cache + Dval = !isnothing(cache_D) ? cache_D : D.val + Aval = !isnothing(cache_A) ? cache_A : A.val + ∂D = isa(D, Const) ? nothing : D.dval + if !isa(A, Const) && !isa(D, Const) + _, V = eig_full(Aval, alg.val) + A.dval .= zero(eltype(Aval)) + PΔV = V' \ Diagonal(D.dval) + if eltype(A.dval) <: Real + ΔAc = PΔV * V' + A.dval .+= real.(ΔAc) + else + mul!(A.dval, PΔV, V', 1, 0) + end + end + if !isa(D, Const) + make_zero!(D.dval) + end + return (nothing, nothing, nothing) +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_vals!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + cache_D = nothing + cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + func.val(A.val, D.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? D.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing + # form cache if needed + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D)) +end +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_vals!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + + cache_A, cache_D = cache + Dval = !isnothing(cache_D) ? cache_D : D.val + Aval = !isnothing(cache_A) ? cache_A : A.val + ∂D = isa(D, Const) ? nothing : D.dval + if !isa(A, Const) && !isa(D, Const) + _, V = eigh_full(Aval, alg.val) + A.dval .= zero(eltype(Aval)) + mul!(A.dval, V * Diagonal(real(∂D)), V', 1, 0) + A.dval .*= 2 + diagview(A.dval) ./= 2 + for i in 1:size(A.dval, 1), j in 1:size(A.dval, 2) + if i > j + A.dval[i, j] = zero(eltype(A.dval)) + end + end + end + if !isa(D, Const) + make_zero!(D.dval) + end + return (nothing, nothing, nothing) +end + +end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl new file mode 100644 index 00000000..17971578 --- /dev/null +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -0,0 +1,300 @@ +module MatrixAlgebraKitMooncakeExt + +using Mooncake +using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive +using MatrixAlgebraKit +using MatrixAlgebraKit: inv_safe, diagview +using MatrixAlgebraKit: svd_pushforward! +using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pushforward!, lq_pushforward! +using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pushforward!, lq_null_pushforward! +using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pushforward!, eigh_pushforward! +using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pushforward!, right_polar_pushforward! +using LinearAlgebra + +# two-argument factorizations like LQ, QR, EIG +for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoint), + (qr_compact!, qr_pullback!, qr_pushforward!, :dqr_adjoint), + (lq_full!, lq_pullback!, lq_pushforward!, :dlq_adjoint), + (lq_compact!, lq_pullback!, lq_pushforward!, :dlq_adjoint), + (eig_full!, eig_pullback!, eig_pushforward!, :deig_adjoint), + (eigh_full!, eigh_pullback!, eigh_pushforward!, :deigh_adjoint), + (left_polar!, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint), + (right_polar!, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint), + ) + + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) + A, dA = arrayify(A_dA) + dA .= zero(eltype(A)) + args = Mooncake.primal(args_dargs) + dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(args[1], dargs[1]) + arg2, darg2 = arrayify(args[2], dargs[2]) + function $adj(::Mooncake.NoRData) + dA = $pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...) + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...) + darg1 .= zero(eltype(arg1)) + darg2 .= zero(eltype(arg2)) + return Mooncake.CoDual(args, dargs), $adj + end + @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) + A, dA = arrayify(A_dA) + args = Mooncake.primal(args_dargs) + args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...) + dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(args[1], dargs[1]) + arg2, darg2 = arrayify(args[2], dargs[2]) + darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2)) + dA .= zero(eltype(A)) + return Mooncake.Dual(args, dargs) + end + end +end + +for (f, f_full, pb, pf, adj) in ((qr_null!, qr_full, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint), + (lq_null!, lq_full, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint), + ) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, arg_darg::CoDual{<:AbstractMatrix}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) + A, dA = arrayify(A_dA) + Ac = MatrixAlgebraKit.copy_input($f_full, A) + arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg)) + arg = $f(Ac, arg, Mooncake.primal(alg_dalg)) + function $adj(::Mooncake.NoRData) + dA .= zero(eltype(A)) + $pb(dA, A, arg, darg; kwargs...) + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + return arg_darg, $adj + end + @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(f_df::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, arg_darg::Dual{<:AbstractMatrix}, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) + A, dA = arrayify(A_dA) + Ac = MatrixAlgebraKit.copy_input($f_full, A) + arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg)) + arg = $f(Ac, arg, Mooncake.primal(alg_dalg)) + $pf(dA, A, arg, darg; kwargs...) + dA .= zero(dA) + return arg_darg + end + end +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...) + # compute primal + D_ = Mooncake.primal(D_dD) + dD_ = Mooncake.tangent(D_dD) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + D, dD = arrayify(D_, dD_) + nD, V = eig_full(A, alg_dalg.primal; kwargs...) + + # update tangent + tmp = V \ dA + dD .= diagview(tmp * V) + dA .= zero(eltype(dA)) + return Mooncake.Dual(nD.diag, dD_) +end + +function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...) + # compute primal + D_ = Mooncake.primal(D_dD) + dD_ = Mooncake.tangent(D_dD) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + D, dD = arrayify(D_, dD_) + dA .= zero(eltype(dA)) + # update primal + DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...) + V = DV[2] + dD .= zero(eltype(D)) + function deig_vals_adjoint(::Mooncake.NoRData) + PΔV = V' \ Diagonal(dD) + if eltype(dA) <: Real + ΔAc = PΔV * V' + dA .+= real.(ΔAc) + else + mul!(dA, PΔV, V', 1, 0) + end + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + return Mooncake.CoDual(DV[1].diag, dD_), deig_vals_adjoint +end +#= +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_full!), AbstractMatrix, Tuple{<:Diagonal, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_full!)}, A_dA::CoDual{<:AbstractMatrix}, DV_dDV::CoDual{<:Tuple{<:Diagonal, <:AbstractMatrix}}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) + A, dA = arrayify(A_dA) + dA .= zero(eltype(A)) + DV = Mooncake.primal(DV_dDV) + dDV = Mooncake.tangent(DV_dDV) + D, dD = arrayify(DV[1], dDV[1]) + V, dV = arrayify(DV[2], dDV[2]) + function deigh_adjoint(::Mooncake.NoRData) + dA = MatrixAlgebraKit.eigh_pullback!(dA, A, (D, V), (dD, dV); kwargs...) + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + DV = eigh_full!(A, DV, Mooncake.primal(alg_dalg); kwargs...) + return Mooncake.CoDual(DV, dDV), deigh_adjoint +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eigh_full!), AbstractMatrix, Tuple{<:Diagonal, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.frule!!(::Dual{typeof(MatrixAlgebraKit.eigh_full!)}, A_dA::Dual, DV_dDV::Dual, alg_dalg::Dual; kwargs...) + A, dA = arrayify(A_dA) + DV = Mooncake.primal(DV_dDV) + dDV = Mooncake.tangent(DV_dDV) + D, dD = arrayify(DV[1], dDV[1]) + V, dV = arrayify(DV[2], dDV[2]) + (D, V) = eigh_full!(A, DV, Mooncake.primal(alg_dalg); kwargs...) + (dD, dV) = eigh_pushforward!(dA, A, (D, V), (dD, dV); kwargs...) + return Mooncake.Dual(DV, dDV) +end +=# + +@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...) + # compute primal + D_ = Mooncake.primal(D_dD) + dD_ = Mooncake.tangent(D_dD) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + D, dD = arrayify(D_, dD_) + nD, V = eigh_full(A, alg_dalg.primal; kwargs...) + # update tangent + tmp = inv(V) * dA * V + dD .= real.(diagview(tmp)) + D .= nD.diag + dA .= zero(eltype(dA)) + return D_dD +end + +function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...) + # compute primal + D_ = Mooncake.primal(D_dD) + dD_ = Mooncake.tangent(D_dD) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + D, dD = arrayify(D_, dD_) + DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...) + function deigh_vals_adjoint(::Mooncake.NoRData) + mul!(dA, DV[2] * Diagonal(real(dD)), DV[2]', 1, 0) + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + return Mooncake.CoDual(DV[1].diag, dD_), deigh_vals_adjoint +end + + +for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal)) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...) + A, dA = arrayify(A_dA) + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) + S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) + Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...) + function dsvd_adjoint(::Mooncake.NoRData) + dA .= zero(eltype(A)) + minmn = min(size(A)...) + if size(U, 2) == size(Vᴴ, 1) == minmn # compact + dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + else # full + vU = view(U, :, 1:minmn) + vS = Diagonal(diagview(S)[1:minmn]) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(dU, :, 1:minmn) + vdS = view(dS, 1:minmn, 1:minmn) + vdVᴴ = view(dVᴴ, 1:minmn, :) + dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) + end + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + dU .= zero(dU) + dS .= zero(dS) + dVᴴ .= zero(dVᴴ) + return Mooncake.CoDual(USVᴴ, dUSVᴴ), dsvd_adjoint + end + @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{<:typeof($f)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual; kwargs...) + # compute primal + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + $f(A, USVᴴ, alg_dalg.primal; kwargs...) + + # update tangents + U_, S_, Vᴴ_ = USVᴴ + dU_, dS_, dVᴴ_ = dUSVᴴ + U, dU = arrayify(U_, dU_) + S, dS = arrayify(S_, dS_) + Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_) + minmn = min(size(A)...) + if ($f == svd_compact!) # compact + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...) + else # full + vU = view(U, :, 1:minmn) + vS = view(S, 1:minmn, 1:minmn) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(dU, :, 1:minmn) + vdS = view(dS, 1:minmn, 1:minmn) + vdVᴴ = view(dVᴴ, 1:minmn, :) + svd_pushforward!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ); kwargs...) + end + return USVᴴ_dUSVᴴ + end + end +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual; kwargs...) + # compute primal + S_ = Mooncake.primal(S_dS) + dS_ = Mooncake.tangent(S_dS) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...) + + # update tangent + S, dS = arrayify(S_, dS_) + copyto!(dS, diag(real.(Vᴴ * dA' * U))) + copyto!(S, diagview(nS)) + dA .= zero(eltype(dA)) + return Mooncake.Dual(nS.diag, dS) +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...) + # compute primal + S_ = Mooncake.primal(S_dS) + dS_ = Mooncake.tangent(S_dS) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + S, dS = arrayify(S_, dS_) + U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...) + S .= diagview(nS) + dS .= zero(eltype(S)) + function dsvd_vals_adjoint(::Mooncake.NoRData) + dA .= U * Diagonal(dS) * Vᴴ + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + return S_dS, dsvd_vals_adjoint +end + +end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 4a846f85..1d2f6f61 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -112,4 +112,11 @@ include("pullbacks/eigh.jl") include("pullbacks/svd.jl") include("pullbacks/polar.jl") +include("pushforwards/qr.jl") +include("pushforwards/lq.jl") +include("pushforwards/eig.jl") +include("pushforwards/eigh.jl") +include("pushforwards/polar.jl") +include("pushforwards/svd.jl") + end diff --git a/src/common/view.jl b/src/common/view.jl index c8ae1aa5..0bc7b9ef 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -1,5 +1,5 @@ # diagind: provided by LinearAlgebra.jl -diagview(D::Diagonal) = D.diag +diagview(D::Diagonal) = D.diag diagview(D::AbstractMatrix) = view(D, diagind(D)) # triangularind diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 0cfa6db0..bb766787 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -19,7 +19,7 @@ function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real = end function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm) - check_hermitian(A, alg) + #check_hermitian(A, alg) D, V = DV m = size(A, 1) @assert D isa Diagonal && V isa AbstractMatrix diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 3115b3d5..9d0f8cf3 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -46,7 +46,8 @@ function eig_pullback!( Δgauge < gauge_atol || @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) + VᴴΔV ./= conj.(transpose(D) .- D) + diagview(VᴴΔV) .= zero(eltype(VᴴΔV)) if !iszerotangent(ΔDmat) ΔDvec = diagview(ΔDmat) diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index fabc2c2e..1c6de509 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -4,7 +4,7 @@ Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `WP` and cotangent `ΔWP` of `left_polar(A)`. """ -function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP) +function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...) # Extract the Polar components W, P = WP @@ -34,7 +34,7 @@ end Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `PWᴴ` and cotangent `ΔPWᴴ` of `right_polar(A)`. """ -function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ) +function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...) # Extract the Polar components P, Wᴴ = PWᴴ diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index c0353a3a..a85b7165 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -26,14 +26,13 @@ function svd_pullback!( degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) - # Extract the SVD components U, Smat, Vᴴ = USVᴴ m, n = size(U, 1), size(Vᴴ, 2) - (m, n) == size(ΔA) || throw(DimensionMismatch()) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) minmn = min(m, n) S = diagview(Smat) - length(S) == minmn || throw(DimensionMismatch()) + length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)")) r = searchsortedlast(S, rank_atol; rev = true) # rank Ur = view(U, :, 1:r) Vᴴr = view(Vᴴ, 1:r, :) @@ -44,22 +43,22 @@ function svd_pullback!( UΔU = fill!(similar(U, (r, r)), 0) VΔV = fill!(similar(Vᴴ, (r, r)), 0) if !iszerotangent(ΔU) - m == size(ΔU, 1) || throw(DimensionMismatch()) + m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)")) pU = size(ΔU, 2) - pU > r && throw(DimensionMismatch()) + pU > r && throw(DimensionMismatch("second dimension of ΔU ($(size(ΔU, 2))) does not match rank of S ($r)")) indU = axes(U, 2)[ind] - length(indU) == pU || throw(DimensionMismatch()) + length(indU) == pU || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))")) UΔUp = view(UΔU, :, indU) mul!(UΔUp, Ur', ΔU) # ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1) end if !iszerotangent(ΔVᴴ) - n == size(ΔVᴴ, 2) || throw(DimensionMismatch()) + n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)")) pV = size(ΔVᴴ, 1) - pV > r && throw(DimensionMismatch()) + pV > r && throw(DimensionMismatch("first dimension of ΔVᴴ ($(size(ΔVᴴ, 1))) does not match rank of S ($r)")) indV = axes(Vᴴ, 1)[ind] - length(indV) == pV || throw(DimensionMismatch()) + length(indV) == pV || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))")) VΔVp = view(VΔV, :, indV) mul!(VΔVp, Vᴴr, ΔVᴴ') # ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ @@ -82,7 +81,7 @@ function svd_pullback!( ΔS = diagview(ΔSmat) pS = length(ΔS) indS = axes(S, 1)[ind] - length(indS) == pS || throw(DimensionMismatch()) + length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))")) view(diagview(UdΔAV), indS) .+= real.(ΔS) end ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA diff --git a/src/pushforwards/eig.jl b/src/pushforwards/eig.jl new file mode 100644 index 00000000..19a43cb9 --- /dev/null +++ b/src/pushforwards/eig.jl @@ -0,0 +1,12 @@ +function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...) + D, V = DV + ΔD, ΔV = ΔDV + iVΔAV = inv(V) * ΔA * V + diagview(ΔD) .= diagview(iVΔAV) + F = 1 ./ (transpose(diagview(D)) .- diagview(D)) + fill!(diagview(F), zero(eltype(F))) + K̇ = F .* iVΔAV + mul!(ΔV, V, K̇, 1, 0) + zero!(ΔA) + return ΔDV +end diff --git a/src/pushforwards/eigh.jl b/src/pushforwards/eigh.jl new file mode 100644 index 00000000..69685b16 --- /dev/null +++ b/src/pushforwards/eigh.jl @@ -0,0 +1,16 @@ +function eigh_pushforward!(dA, A, DV, dDV; kwargs...) + D, V = DV + dD, dV = dDV + tmpV = V \ dA + ∂K = tmpV * V + ∂Kdiag = diag(∂K) + dD.diag .= real.(∂Kdiag) + dDD = transpose(diagview(D)) .- diagview(D) + F = one(eltype(dDD)) ./ dDD + diagview(F) .= zero(eltype(F)) + ∂K .*= F + ∂V = mul!(tmpV, V, ∂K) + copyto!(dV, ∂V) + dA .= zero(eltype(A)) + return (dD, dV) +end diff --git a/src/pushforwards/lq.jl b/src/pushforwards/lq.jl new file mode 100644 index 00000000..2d390a59 --- /dev/null +++ b/src/pushforwards/lq.jl @@ -0,0 +1,68 @@ +#=function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) + L, Q = LQ + dL, dQ = dLQ + m = size(L, 1) + n = size(Q, 2) + minmn = min(m, n) + Ld = diagview(L) + p = findlast(>=(rank_atol) ∘ abs, Ld) + + if p == minmn && size(L,1) == size(L,2) # full-rank + invL = inv(L) + dQ .= invL * (dA - dL * Q) + dL = invL * dA * Q' + return (dL, dQ) + end + + n1 = p + n2 = minmn - p + n3 = n - minmn + m1 = p + m2 = m - p + + ##### + Q1 = view(Q, 1:m1, 1:n) # full rank portion + Q2 = view(Q, n1+1:n1+n2, 1:n) + L11 = view(L, 1:m1, 1:n1) + L21 = view(L, (m1+1):m, 1:n1) + + dA1 = view(dA, 1:m1, 1:n) + dA2 = view(dA, (m1+1):m, 1:n) + + dQ1 = view(dQ, 1:n1, 1:n) + dQ2 = view(dQ, n1+1:n1+n2, 1:n) + dL11 = view(dL, 1:m1, 1:n1) + dL21 = view(dL, (m1+1):m, 1:n1) + dL22 = view(dL, (m1+1):m, n1+1:(n1+n2) ) + + # fwd rule for Q1 and R11 -- for a non-rank redeficient QR, this is all we need + invL11 = inv(L11) + tmp = invL11 * dA1 * Q1' + Ltmp = tmp + tmp' + diagview(Ltmp) ./= 2 + utLtmp = view(Ltmp, MatrixAlgebraKit.uppertriangularind(Ltmp)) + dL11 .= L11 * Ltmp + dQ1 .= invL11 * dA1 - invL11 * dL11 * Q1 + + dL21 .= (dA2 - L21 * dQ1) * adjoint(Q1) + dQ2 .= -(dQ2 * Q1') * Q1 + if size(Q2, 1) > 0 + dQ2 .+= Q2 * (Q2' * dQ2) + end + if n3 > 0 && size(dQ2, 1) > 0 + # only present for qr_full or rank-deficient qr_compact + Q3 = view(Q, (n1+n2+1):n, 1:n) + dQ2 .+= Q3 * (Q3' * dQ2) + end + if !isempty(dL22) + _, l22 = qr_full(dA2 - L21 * dQ1 - dL12 * Q1, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true)) + dL22 .= view(l22, 1:size(dL22, 1), 1:size(dL22, 2)) + end + return (dL, dQ) +end=# + +function lq_pushforward!(dA, A, LQ, dLQ; kwargs...) + qr_pushforward!(dA, A, (adjoint(LQ[2]), adjoint(LQ[1])), (adjoint(dLQ[2]), adjoint(dLQ[1])); kwargs...) +end + +function lq_null_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) end diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl new file mode 100644 index 00000000..2001c41a --- /dev/null +++ b/src/pushforwards/polar.jl @@ -0,0 +1,23 @@ +function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...) + W, P = WP + ΔW, ΔP = ΔWP + aWdA = adjoint(W) * ΔA + K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA))) + L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W*adjoint(W))*ΔA*inv(P) + ΔW .= W * K̇ + L̇ + ΔP .= aWdA - K̇*P + MatrixAlgebraKit.zero!(ΔA) + return (ΔW, ΔP) +end + +function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) + P, Wᴴ = PWᴴ + ΔP, ΔWᴴ = ΔPWᴴ + dAW = ΔA * adjoint(Wᴴ) + K̇ = sylvester(P, P, -(dAW - adjoint(dAW))) + L̇ = inv(P)*ΔA*(Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ) + ΔWᴴ .= K̇ * Wᴴ + L̇ + ΔP .= dAW - P * K̇ + MatrixAlgebraKit.zero!(ΔA) + return (ΔWᴴ, ΔP) +end diff --git a/src/pushforwards/qr.jl b/src/pushforwards/qr.jl new file mode 100644 index 00000000..26672559 --- /dev/null +++ b/src/pushforwards/qr.jl @@ -0,0 +1,59 @@ +function qr_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) + Q, R = QR + m = size(A, 1) + n = size(A, 2) + minmn = min(m, n) + Rd = diagview(R) + p = findlast(>=(rank_atol) ∘ abs, Rd) + + m1 = p + m2 = minmn - p + m3 = m - minmn + n1 = p + n2 = n - p + + Q1 = view(Q, 1:m, 1:m1) # full rank portion + Q2 = view(Q, 1:m, m1+1:m2+m1) + R11 = view(R, 1:m1, 1:n1) + R12 = view(R, 1:m1, n1+1:n) + + dA1 = view(dA, 1:m, 1:n1) + dA2 = view(dA, 1:m, (n1 + 1):n) + + dQ, dR = dQR + dQ1 = view(dQ, 1:m, 1:m1) + dQ2 = view(dQ, 1:m, m1+1:m2+m1) + dQ3 = minmn+1 < size(dQ, 2) ? view(dQ, :, minmn+1:size(dQ,2)) : similar(dQ, eltype(dQ), (0, 0)) + dR11 = view(dR, 1:m1, 1:n1) + dR12 = view(dR, 1:m1, n1+1:n) + dR22 = view(dR, m1+1:m1+m2, n1+1:n) + + # fwd rule for Q1 and R11 -- for a non-rank redeficient QR, this is all we need + invR11 = inv(R11) + tmp = Q1' * dA1 * invR11 + Rtmp = tmp + tmp' + diagview(Rtmp) ./= 2 + ltRtmp = view(Rtmp, MatrixAlgebraKit.lowertriangularind(Rtmp)) + ltRtmp .= zero(eltype(Rtmp)) + dR11 .= Rtmp * R11 + dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11 + dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12) + if size(Q2, 2) > 0 + dQ2 .= -Q1 * (Q1' * Q2) + dQ2 .+= Q2 * (Q2' * dQ2) + end + if m3 > 0 && size(Q, 2) > minmn + # only present for qr_full or rank-deficient qr_compact + Q′ = view(Q, :, 1:minmn) + Q3 = view(Q, :, minmn+1:m) + #dQ3 .= Q′ * (Q′' * Q3) + dQ3 .= Q3 + end + if !isempty(dR22) + _, r22 = qr_full(dA2 - dQ1*R12 - Q1*dR12, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true)) + dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2)) + end + return (dQ, dR) +end + +function qr_null_pushforward!(dA, A, N, dN; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(N), rank_atol::Real=tol, gauge_atol::Real=tol) end diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl new file mode 100644 index 00000000..f6abc689 --- /dev/null +++ b/src/pushforwards/svd.jl @@ -0,0 +1,83 @@ +function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; + tol::Real = default_pullback_gaugetol(USVᴴ[2]), + rank_atol::Real = tol, + degeneracy_atol::Real = tol, + gauge_atol::Real = tol + ) + U, Smat, Vᴴ = USVᴴ + m, n = size(U, 1), size(Vᴴ, 2) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) + minmn = min(m, n) + S = diagview(Smat) + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ + r = searchsortedlast(S, rank_atol; rev = true) # rank + + vΔU = view(ΔU, :, 1:r) + vΔS = view(ΔS, 1:r, 1:r) + vΔVᴴ = view(ΔVᴴ, 1:r, :) + + vU = view(U, :, 1:r) + vS = view(S, 1:r) + vSmat = view(Smat, 1:r, 1:r) + vVᴴ = view(Vᴴ, 1:r, :) + + # compact region + vV = adjoint(vVᴴ) + UΔAV = vU' * ΔA * vV + copyto!(diagview(vΔS), diag(real.(UΔAV))) + F = one(eltype(S)) ./ (transpose(vS) .- vS) + G = one(eltype(S)) ./ (transpose(vS) .+ vS) + diagview(F) .= zero(eltype(F)) + hUΔAV = F .* (UΔAV + UΔAV') ./ 2 + aUΔAV = G .* (UΔAV - UΔAV') ./ 2 + K̇ = hUΔAV + aUΔAV + Ṁ = hUΔAV - aUΔAV + + # check gauge condition + @assert isantihermitian(K̇) + @assert isantihermitian(Ṁ) + K̇diag = diagview(K̇) + for i in 1:length(K̇diag) + @assert K̇diag[i] ≈ (im/2) * imag(diagview(UΔAV)[i])/S[i] + end + + ∂U = vU * K̇ + ∂V = vV * Ṁ + # full component + if size(U, 2) > minmn && size(Vᴴ, 1) > minmn + Uperp = view(U, :, minmn+1:m) + Vᴴperp = view(Vᴴ, minmn+1:n, :) + + aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp) + + UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2))) + fill!(UÃÃV, 0) + view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV + view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV' + rhs = vcat( adjoint(Uperp, ΔA, V), Vᴴperp * ΔA' * U) + superKM = -sylvester(UÃÃV, Smat, rhs) + K̇perp = view(superKM, 1:size(aUAV, 2)) + Ṁperp = view(superKM, size(aUAV, 2)+1:size(aUAV, 1)+size(aUAV, 2)) + ∂U .+= Uperp * K̇perp + ∂V .+= Vperp * Ṁperp + else + ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU*vU') + ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV*vVᴴ) + upper = ImUU * ΔA * vV + lower = ImVV * ΔA' * vU + rhs = vcat(upper, lower) + + Ã = ImUU * A * ImVV + ÃÃ = similar(A, (m + n, m + n)) + fill!(ÃÃ, 0) + view(ÃÃ, (1:m), m .+ (1:n)) .= Ã + view(ÃÃ, m .+ (1:n), 1:m ) .= Ã' + + superLN = -sylvester(ÃÃ, vSmat, rhs) + ∂U += view(superLN, 1:size(upper, 1), :) + ∂V += view(superLN, size(upper, 1)+1:size(upper,1)+size(lower,1), :) + end + copyto!(vΔU, ∂U) + adjoint!(vΔVᴴ, ∂V) + return (ΔU, ΔS, ΔVᴴ) +end diff --git a/test/ad_utils.jl b/test/ad_utils.jl new file mode 100644 index 00000000..11f3b02e --- /dev/null +++ b/test/ad_utils.jl @@ -0,0 +1,26 @@ +function remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) + gaugepart = U' * ΔU + Vᴴ * ΔVᴴ' + gaugepart = (gaugepart - gaugepart') / 2 + gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end +function remove_eiggauge_dependence!(ΔV, D, V; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) + gaugepart = V' * ΔV + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end +function remove_eighgauge_dependence!(ΔV, D, V; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) + gaugepart = V' * ΔV + gaugepart = (gaugepart - gaugepart') / 2 + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end + +precision(::Type{<:Union{Float32,Complex{Float32}}}) = sqrt(eps(Float32)) +precision(::Type{<:Union{Float64,Complex{Float64}}}) = sqrt(eps(Float64)) diff --git a/test/chainrules.jl b/test/chainrules.jl index ba3f0681..76eb84c8 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -6,38 +6,7 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! -function remove_svdgauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = U' * ΔU + Vᴴ * ΔVᴴ' - gaugepart = (gaugepart - gaugepart') / 2 - gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 - mul!(ΔU, U, gaugepart, -1, 1) - return ΔU, ΔVᴴ -end -function remove_eiggauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D) - ) - gaugepart = V' * ΔV - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V / (V' * V), gaugepart, -1, 1) - return ΔV -end -function remove_eighgauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D) - ) - gaugepart = V' * ΔV - gaugepart = (gaugepart - gaugepart') / 2 - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V, gaugepart, -1, 1) - return ΔV -end - -precision(::Type{<:Union{Float32, Complex{Float32}}}) = sqrt(eps(Float32)) -precision(::Type{<:Union{Float64, Complex{Float64}}}) = sqrt(eps(Float64)) +include("ad_utils.jl") for f in ( diff --git a/test/cuda/enzyme.jl b/test/cuda/enzyme.jl new file mode 100644 index 00000000..12caad8c --- /dev/null +++ b/test/cuda/enzyme.jl @@ -0,0 +1,326 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using ChainRulesCore +using Enzyme, EnzymeTestUtils +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +is_ci = get(ENV, "CI", "false") == "true" + +ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF64) # Enzyme/#2631 + +include("ad_utils.jl") +@timedtestset "QR AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + A = CuArray(randn(rng, T, m, n)) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + @testset for alg in (LAPACK_HouseholderQR(), + LAPACK_HouseholderQR(; positive=true), + ) + #=@testset "forward" begin + @testset "qr_compact" begin + Q, R = qr_compact(A; alg=alg) + ΔQ = randn(rng, T, m, minmn) + ΔR = randn(rng, T, minmn, n) + # run reverse, then forwards, see if we recover + ΔQ2 = copy(ΔQ) + ΔR2 = copy(ΔR) + ΔA = zeros(T, m, n) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔQ .= zero(T) + ΔR .= zero(T) + ΔA2 = copy(ΔA) + ΔQ, ΔR = MatrixAlgebraKit.qr_fwd!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔA .= zero(T) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + @test ΔA ≈ ΔA2 atol=atol rtol=rtol + #test_forward(qr_compact, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(ΔQ, ΔR)) + end + #=@testset "qr_null" begin + Q, R = qr_compact(A, alg) + ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) + ΔN2 = copy(ΔN) + ΔA = zeros(T, m, n) + ΔA = MatrixAlgebraKit.qr_null_pullback!(ΔA, A, N, ΔN) + #test_forward(qr_null, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=ΔN) + end=# + @testset "qr_full" begin + Q, R = qr_full(A, alg) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn(rng, T, m, m) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn(rng, T, m, n) + ΔQ2 = copy(ΔQ) + ΔR2 = copy(ΔR) + ΔA = zeros(T, m, n) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔA2 = copy(ΔA) + ΔQ, ΔR = MatrixAlgebraKit.qr_fwd!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + @test ΔA ≈ ΔA2 atol=atol rtol=rtol + #test_reverse(qr_full, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(ΔQ, ΔR)) + end + @testset "qr_compact - rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + Q, R = qr_compact(Ard, alg) + ΔQ = randn(rng, T, m, minmn) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = randn(rng, T, minmn, n) + view(ΔR, (r + 1):minmn, :) .= 0 + ΔQ2 = copy(ΔQ) + ΔR2 = copy(ΔR) + ΔA = zeros(T, m, n) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔA2 = copy(ΔA) + ΔQ, ΔR = MatrixAlgebraKit.qr_fwd!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + @test ΔA ≈ ΔA2 atol=atol rtol=rtol + #test_forward(qr_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR)) + end + end=# + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "qr_compact" begin + ΔQ = CuArray(randn(rng, T, m, minmn)) + ΔR = CuArray(randn(rng, T, minmn, n)) + test_reverse(qr_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR)) + end + @testset "qr_null" begin + Q, R = qr_compact(A, alg) + ΔN = Q * CuArray(randn(rng, T, minmn, max(0, m - minmn))) + test_reverse(qr_null, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=ΔN) + end + @testset "qr_full" begin + Q, R = qr_full(A, alg) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = CuArray(randn(rng, T, m, m)) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = CuArray(randn(rng, T, m, n)) + test_reverse(qr_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR)) + end + @testset "qr_compact - rank-deficient A" begin + r = minmn - 5 + Ard = CuArray(randn(rng, T, m, r) * randn(rng, T, r, n)) + Q, R = qr_compact(Ard, alg) + ΔQ = CuArray(randn(rng, T, m, minmn)) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = CuArray(randn(rng, T, minmn, n)) + view(ΔR, (r + 1):minmn, :) .= 0 + test_reverse(qr_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR)) + end + end + end + end +end + +@timedtestset "LQ AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + A = CuArray(randn(rng, T, m, n)) + @testset for alg in (LAPACK_HouseholderLQ(), + LAPACK_HouseholderLQ(; positive=true), + ) + #@testset "forward: RT $RT, TA $TA" for RT in (Const,Duplicated,DuplicatedNoNeed), TA in (Const,Duplicated,) + #test_forward(lq_full, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + #test_forward(lq_null, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + #test_forward(lq_compact, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + #end + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "lq_compact" begin + ΔL = CuArray(randn(rng, T, m, minmn)) + ΔQ = CuArray(randn(rng, T, minmn, n)) + test_reverse(lq_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ)) + end + @testset "lq_null" begin + L, Q = lq_compact(A, alg) + ΔNᴴ = CuArray(randn(rng, T, max(0, n - minmn), minmn)) * Q + test_reverse(lq_null, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=ΔNᴴ) + end + @testset "lq_full" begin + L, Q = lq_full(A, alg) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = CuArray(randn(rng, T, n, n)) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + ΔL = CuArray(randn(rng, T, m, n)) + test_reverse(lq_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ)) + end + @testset "lq_compact -- rank-deficient A" begin + r = minmn - 5 + Ard = CuArray(randn(rng, T, m, r) * randn(rng, T, r, n)) + L, Q = lq_compact(Ard, alg) + ΔL = CuArray(randn(rng, T, m, minmn)) + ΔQ = CuArray(randn(rng, T, minmn, n)) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + test_reverse(lq_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ)) + end + end + end + end +end + +@timedtestset "EIG AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = CuArray(randn(rng, T, m, m)) + D, V = eig_full(A) + ΔV = CuArray(randn(rng, complex(T), m, m)) + ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = CuArray(randn(rng, complex(T), m, m)) + ΔD2 = Diagonal(randn(rng, complex(T), m)) + @testset for alg in (LAPACK_Simple(), LAPACK_Expert()) + @testset "forward: RT $RT, TA $TA" for RT in (Const, Duplicated,), TA in (Const, Duplicated,) + #=ΔV2 = copy(ΔV) + ΔD2_ = copy(ΔD2) + ΔA = zeros(T, m, m) + ΔA = MatrixAlgebraKit.eig_pullback!(ΔA, copy(A), (D, V), (ΔD2, ΔV)) + ΔA2 = copy(ΔA) + ΔD2_, ΔV = MatrixAlgebraKit.eig_full_fwd!(ΔA, copy(A), (D, V), (ΔD2, ΔV)) + ΔA = MatrixAlgebraKit.eig_pullback!(ΔA, copy(A), (D, V), (ΔD2_, ΔV)) + @test ΔA ≈ ΔA2 atol=atol rtol=rtol=# # TODO + test_forward(eig_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + end + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + test_reverse(eig_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) + test_reverse(eig_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) + end + end +end + +@timedtestset "EIGH AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = CuArray(randn(rng, T, m, m)) + A = A + A' + D, V = eigh_full(A) + D2 = Diagonal(D) + ΔV = CuArray(randn(rng, T, m, m)) + ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = CuArray(randn(rng, real(T), m, m)) + ΔD2 = Diagonal(CuArray(randn(rng, real(T), m))) + @testset for alg in (LAPACK_QRIteration(), + LAPACK_DivideAndConquer(), + LAPACK_Bisection(), + LAPACK_MultipleRelativelyRobustRepresentations(), + ) + @testset "forward: RT $RT, TA $TA" for RT in (Const, Duplicated,), TA in (Const, Duplicated,) + test_forward(eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + end + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + test_reverse(eigh_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) + test_reverse(eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) + end + end +end + +@timedtestset "SVD AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = CuArray(randn(rng, T, m, n)) + minmn = min(m, n) + @testset for alg in (LAPACK_QRIteration(), + LAPACK_DivideAndConquer(), + ) + #=@testset "forward: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + test_rewind(svd_compact, RT, Duplicated, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔU, ΔS, ΔVᴴ)) + #test_forward(svd_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + end=# # TODO + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "svd_compact" begin + U, S, Vᴴ = svd_compact(A) + ΔU = CuArray(randn(rng, T, m, minmn)) + ΔS = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = CuArray(randn(rng, T, minmn, n)) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + test_reverse(svd_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔU, ΔS, ΔVᴴ)) + end + #= + @testset "svd_full" begin + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, m) + ΔS = randn(rng, real(T), m, n) + ΔVᴴ = randn(rng, T, n, n) + remove_svdgauge_dependence!(view(ΔU, :, 1:minmn), view(ΔVᴴ, 1:minmn, :), view(U, :, 1:minmn), view(S, 1:minmn, 1:minmn), view(Vᴴ, 1:minmn, :); degeneracy_atol=atol) + test_reverse(svd_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔU, ΔS, ΔVᴴ)) + end =# # TODO + end + end + end +end + +@timedtestset "Polar AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = CuArray(randn(rng, T, m, n)) + @testset for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + m >= n && + test_reverse(left_polar, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + m <= n && + test_reverse(right_polar, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + end + end + end +end + +@timedtestset "Orth and null with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = CuArray(randn(rng, T, m, n)) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "left_orth" begin + @testset for kind in (:polar, :qr) + n > m && kind == :polar && continue + test_reverse(left_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(kind=kind,)) + end + end + @testset "right_orth" begin + @testset for kind in (:polar, :lq) + n < m && kind == :polar && continue + test_reverse(right_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(kind=kind,)) + end + end + @testset "left_null" begin + ΔN = left_orth(A; kind=:qr)[1] * CuArray(randn(rng, T, min(m, n), m - min(m, n))) + test_reverse(left_null, RT, (A, TA); fkwargs=(; kind=:qr), output_tangent=ΔN, atol=atol, rtol=rtol) + end + @testset "right_null" begin + ΔNᴴ = CuArray(randn(rng, T, n - min(m, n), min(m, n))) * right_orth(A; kind=:lq)[2] + test_reverse(right_null, RT, (A, TA); fkwargs=(; kind=:lq), output_tangent=ΔNᴴ, atol=atol, rtol=rtol) + end + end + end +end diff --git a/test/cuda/mooncake.jl b/test/cuda/mooncake.jl new file mode 100644 index 00000000..00bf3762 --- /dev/null +++ b/test/cuda/mooncake.jl @@ -0,0 +1,210 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using Mooncake, Mooncake.TestUtils, ChainRulesCore +using Mooncake: rrule!! +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using CUDA +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +function remove_svdgauge_depence!(ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) + gaugepart = U' * ΔU + Vᴴ * ΔVᴴ' + gaugepart = (gaugepart - gaugepart') / 2 + gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end +function remove_eiggauge_depence!(ΔV, D, V; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) + gaugepart = V' * ΔV + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end +function remove_eighgauge_depence!(ΔV, D, V; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) + gaugepart = V' * ΔV + gaugepart = (gaugepart - gaugepart') / 2 + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end + +precision(::Type{<:Union{Float32,Complex{Float32}}}) = sqrt(eps(Float32)) +precision(::Type{<:Union{Float64,Complex{Float64}}}) = sqrt(eps(Float64)) + +for f in + (:qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, #:eig_full, :eigh_full, + :eigh_full, :svd_compact, :svd_trunc, :left_polar, :right_polar) + copy_f = Symbol(:copy_, f) + f! = Symbol(f, '!') + @eval begin + function $copy_f(input, alg) + if $f === eigh_full + input = (input + input') / 2 + end + return $f(input, alg) + end + function ChainRulesCore.rrule(::typeof($copy_f), input, alg) + output = MatrixAlgebraKit.initialize_output($f!, input, alg) + if $f === eigh_full + input = (input + input') / 2 + else + input = copy(input) + end + output, pb = ChainRulesCore.rrule($f!, input, output, alg) + return output, x -> (NoTangent(), pb(x)[2], NoTangent()) + end + Mooncake.@from_chainrules Mooncake.DefaultCtx Tuple{typeof($copy_f), AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm} false Mooncake.ReverseMode + end +end + +for f in (:eig_full,)#:eigh_full) + copy_f = Symbol(:copy_, f) + f! = Symbol(f, '!') + @eval begin + function $copy_f(input, alg) + if $f === eigh_full + input = (input + input') / 2 + end + return $f(input, alg) + end + end +end + +@timedtestset "QR AD Rules with eltype $T" for T in (Float64, Float32) #, ComplexF64) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + # qr_compact + atol = rtol = m * n * precision(T) + A = CUDA.randn(rng, T, m, n) + minmn = min(m, n) + alg = CUSOLVER_HouseholderQR(; positive=true) + @testset for f in (copy_qr_compact, copy_qr_null, copy_qr_full) + Mooncake.TestUtils.test_rule(rng, f, A, alg; mode=Mooncake.ReverseMode) + end + + # rank-deficient A + r = minmn - 5 + A = CUDA.randn(rng, T, m, r) * CUDA.randn(rng, T, r, n) + Q, R = qr_compact(A, alg) + Mooncake.TestUtils.test_rule(rng, copy_qr_compact, A, alg; mode=Mooncake.ReverseMode) + end +end + +@timedtestset "LQ AD Rules with eltype $T" for T in (Float64, Float32) #, ComplexF64) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + # lq_compact + atol = rtol = m * n * precision(T) + A = CUDA.randn(rng, T, m, n) + minmn = min(m, n) + alg = CUSOLVER_HouseholderLQ(; positive=true) + @testset for f in (copy_lq_compact, copy_lq_null, copy_lq_full) + Mooncake.TestUtils.test_rule(rng, f, A, alg; mode=Mooncake.ReverseMode) + end + # rank-deficient A + r = minmn - 5 + A = CUDA.randn(rng, T, m, r) * CUDA.randn(rng, T, r, n) + Mooncake.TestUtils.test_rule(rng, copy_lq_compact, A, alg; mode=Mooncake.ReverseMode) + end +end + +@timedtestset "EIG AD Rules with eltype $T" for T in (Float64, Float32, ComplexF64) + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = CUDA.randn(rng, T, m, m) + @testset for alg in (CUSOLVER_Simple(), CUSOLVER_Expert()) + Mooncake.TestUtils.test_rule(rng, copy_eig_full, A, alg; mode=Mooncake.ReverseMode) + end +end +@timedtestset "EIGH AD Rules with eltype $T" for T in (Float64, Float32) #, ComplexF64) + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = CUDA.randn(rng, T, m, m) + A = A + A' + @testset for alg in (CUSOLVER_QRIteration(), CUSOLVER_DivideAndConquer(), CUSOLVER_Bisection(), + CUSOLVER_MultipleRelativelyRobustRepresentations()) + # copy_eigh_full includes a projector onto the Hermitian part of the matrix + Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode=Mooncake.ReverseMode) + end +end + +@timedtestset "SVD AD Rules with eltype $T" for T in (Float64, Float32) #, ComplexF64) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = CUDA.randn(rng, T, m, n) + minmn = min(m, n) + U, S, Vᴴ = svd_compact(A) + @testset for alg in (CUSOLVER_QRIteration(), CUSOLVER_DivideAndConquer()) + Mooncake.TestUtils.test_rule(rng, copy_svd_compact, A, alg; mode=Mooncake.ReverseMode) + @testset for r in 1:4:minmn + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + Mooncake.TestUtils.test_rule(rng, copy_svd_trunc, A, truncalg; mode=Mooncake.ReverseMode) + end + truncalg = TruncatedAlgorithm(alg, trunctol(S[1, 1] / 2)) + r = findlast(>=(S[1, 1] / 2), diagview(S)) + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + Mooncake.TestUtils.test_rule(rng, copy_svd_trunc, A, truncalg; mode=Mooncake.ReverseMode) + end + end +end + +@timedtestset "Polar AD Rules with eltype $T" for T in (Float64, Float32) #, ComplexF64) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = CUDA.randn(rng, T, m, n) + @testset for alg in PolarViaSVD.((CUSOLVER_QRIteration(), CUSOLVER_DivideAndConquer())) + m >= n && + Mooncake.TestUtils.test_rule(rng, copy_left_polar, A, alg; mode=Mooncake.ReverseMode) + + m <= n && + Mooncake.TestUtils.test_rule(rng, copy_right_polar, A, alg; mode=Mooncake.ReverseMode) + end + end +end + +#= +@timedtestset "Orth and null with eltype $T" for T in (Float64, ComplexF64, Float32) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + config = Zygote.ZygoteRuleConfig() + test_rrule(config, left_orth, A; + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + test_rrule(config, left_orth, A; fkwargs=(; kind=:qr), + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + m >= n && + test_rrule(config, left_orth, A; fkwargs=(; kind=:polar), + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + + ΔN = left_orth(A; kind=:qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + test_rrule(config, left_null, A; fkwargs=(; kind=:qr), output_tangent=ΔN, + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + + test_rrule(config, right_orth, A; + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + test_rrule(config, right_orth, A; fkwargs=(; kind=:lq), + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + m <= n && + test_rrule(config, right_orth, A; fkwargs=(; kind=:polar), + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lq)[2] + test_rrule(config, right_null, A; fkwargs=(; kind=:lq), output_tangent=ΔNᴴ, + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + end +end +=# diff --git a/test/enzyme.jl b/test/enzyme.jl new file mode 100644 index 00000000..fc4435d1 --- /dev/null +++ b/test/enzyme.jl @@ -0,0 +1,367 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra +using Enzyme, EnzymeTestUtils +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +is_ci = get(ENV, "CI", "false") == "true" + +#ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631 +ETs = (Float64, ComplexF64) +include("ad_utils.jl") + +@timedtestset "QR AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + A = randn(rng, T, m, n) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + @testset for alg in (LAPACK_HouseholderQR(), + LAPACK_HouseholderQR(; positive=true), + ) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "qr_compact" begin + ΔQ = randn(rng, T, m, minmn) + ΔR = randn(rng, T, minmn, n) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_forward(qr_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), fdm=fdm) + end + #=@testset "qr_null" begin + Q, R = qr_compact(A, alg) + ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) + test_forward(qr_null, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + end + @testset "qr_full" begin + Q, R = qr_full(A, alg) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn(rng, T, m, m) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn(rng, T, m, n) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_forward(qr_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), fdm=fdm) + end=# + #=@testset "qr_compact - rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + Q, R = qr_compact(Ard, alg) + ΔQ = randn(rng, T, m, minmn) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = randn(rng, T, minmn, n) + view(ΔR, (r + 1):minmn, :) .= 0 + test_forward(qr_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + end=# + end + end + end +end +#= +@timedtestset "LQ AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + A = randn(rng, T, m, n) + @testset for alg in (LAPACK_HouseholderLQ(), + LAPACK_HouseholderLQ(; positive=true), + ) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "lq_compact" begin + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(lq_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ), fdm=fdm) + end + @testset "lq_null" begin + L, Q = lq_compact(A, alg) + ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q + test_reverse(lq_null, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=ΔNᴴ) + end + @testset "lq_full" begin + L, Q = lq_full(A, alg) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = randn(rng, T, n, n) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + ΔL = randn(rng, T, m, n) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(lq_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ), fdm=fdm) + end + @testset "lq_compact -- rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + L, Q = lq_compact(Ard, alg) + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(lq_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ), fdm=fdm) + end + end + end + end +end +=# + + +@timedtestset "EIG AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = randn(rng, T, m, m) + D, V = eig_full(A) + Ddiag = diagview(D) + ΔA = randn(rng, T, m, m) + ΔV = randn(rng, complex(T), m, m) + ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = randn(rng, complex(T), m, m) + ΔD2 = Diagonal(randn(rng, complex(T), m)) + @testset for alg in (LAPACK_Simple(), + #LAPACK_Expert(), + ) + @testset "forward: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + test_forward(eig_full, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + test_forward(eig_vals, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + end + #=@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(diagview(D), truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + test_forward(eig_trunc, RT, (A, TA), (truncalg, Const); atol=atol, rtol=rtol) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(Ddiag[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + test_forward(eig_trunc, RT, (A, TA), (truncalg, Const); atol=atol, rtol=rtol) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end=# + end +end + +function copy_eigh_full(A; kwargs...) + A = (A + A')/2 + eigh_full(A; kwargs...) +end + +function copy_eigh_vals(A; kwargs...) + A = (A + A')/2 + eigh_vals(A; kwargs...) +end + +@timedtestset "EIGH AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = randn(rng, T, m, m) + A = A + A' + D, V = eigh_full(A) + D2 = Diagonal(D) + ΔV = randn(rng, T, m, m) + ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = randn(rng, real(T), m, m) + ΔD2 = Diagonal(randn(rng, real(T), m)) + @testset for alg in (LAPACK_QRIteration(), + #LAPACK_DivideAndConquer(), + #LAPACK_Bisection(), + #LAPACK_MultipleRelativelyRobustRepresentations(), + ) + @testset "forward: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + test_forward((A; kwargs...)->copy_eigh_full(A; kwargs...)[1], RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + test_forward((A; kwargs...)->copy_eigh_full(A; kwargs...)[2], RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + test_forward(copy_eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + end + #=@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + for r in 1:4:m + Ddiag = diagview(D) + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + test_forward(eigh_trunc, RT, (A, TA), (truncalg, Const); atol=atol, rtol=rtol) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + Ddiag = diagview(D) + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + test_forward(eigh_trunc, RT, (A, TA), (truncalg, Const); atol=atol, rtol=rtol) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end=# + end +end + +@timedtestset "SVD AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + minmn = min(m, n) + @testset for alg in (LAPACK_QRIteration(), + #LAPACK_DivideAndConquer(), + ) + @testset "forward: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + function mak_pinv_compact(A, alg) + U, S, Vᴴ = svd_compact(A, alg) + return U * inv(S) * Vᴴ + end + function mak_pinv_full(A, alg) + U, S, Vᴴ = svd_full(A, alg) + return U * inv(S) * Vᴴ + end + @testset "svd_compact" begin + A = randn(rng, T, m, m) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_forward(mak_pinv_compact, RT, (A, TA), (alg, Const); atol=atol, rtol=rtol, fdm=fdm) + test_forward((A, alg)->svd_compact(A, alg)[2], RT, (A, TA), (alg, Const); atol=atol, rtol=rtol, fdm=fdm) + end + #=@testset "svd_full" begin # horrible Enzyme error here + A = randn(rng, T, m, m) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_forward(mak_pinv_full, RT, (A, TA), (alg, Const); atol=atol, rtol=rtol, fdm=fdm) + test_forward((A, alg)->svd_full(A, alg)[2], RT, (A, TA), (alg, Const); atol=atol, rtol=rtol, fdm=fdm) + end=# + end + #=@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "svd_compact" begin + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(svd_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔU, ΔS, ΔVᴴ), fdm=fdm) + end + end + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "svd_trunc" begin + for r in 1:4:minmn + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] + ΔVᴴtrunc = ΔVᴴ[ind, :] + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + # broken due to Enzyme + ϵ = [zero(real(T))] + test_reverse(dummy_svd_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]), fdm=fdm) + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), copy(A), (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), copy(A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] + ΔVᴴtrunc = ΔVᴴ[ind, :] + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + # broken due to Enzyme + ϵ = [zero(real(T))] + test_reverse(dummy_svd_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]), fdm=fdm) + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), copy(A), (copy(U), copy(S), copy(Vᴴ)), (copy(ΔUtrunc), copy(ΔStrunc), copy(ΔVᴴtrunc)), ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), copy(A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end=# + end + end +end + +@timedtestset "Polar AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + @testset for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + m >= n && + test_reverse(left_polar, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + m <= n && + test_reverse(right_polar, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + end + end + end +end + +@timedtestset "Orth and null with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "left_orth" begin + @testset for alg in (:polar, :qr,) + n > m && alg == :polar && continue + test_forward(left_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + end + end + #=@testset "right_orth" begin + @testset for alg in (:polar, :lq) + n < m && alg == :polar && continue + test_reverse(right_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + end + end=# + #=@testset "left_null" begin + test_forward(left_null, RT, (A, TA); fkwargs=(; alg=:qr), atol=atol, rtol=rtol) + end=# + #=@testset "right_null" begin + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg=:lq)[2] + test_reverse(right_null, RT, (A, TA); fkwargs=(; alg=:lq), output_tangent=ΔNᴴ, atol=atol, rtol=rtol) + end=# + end + end +end + diff --git a/test/mooncake.jl b/test/mooncake.jl new file mode 100644 index 00000000..f73c4db4 --- /dev/null +++ b/test/mooncake.jl @@ -0,0 +1,425 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using Mooncake, Mooncake.TestUtils, ChainRulesCore +using Mooncake: rrule!! +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +function Mooncake.increment!!(x::Tuple{Matrix{T}, Mooncake.Tangent{@NamedTuple{diag::Vector{T}}}, Matrix{T}}, y::Tuple{Matrix{T}, Mooncake.Tangent{@NamedTuple{diag::Vector{T}}}, Matrix{T}}) where {T<:Real} + return (Mooncake.increment!!(x[1], y[1]), Mooncake.increment!!(x[2], y[2]), Mooncake.increment!!(x[3], y[3])) +end +function Mooncake.increment!!(x::Tuple{Mooncake.Tangent{@NamedTuple{diag::Vector{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}}, Matrix{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}, y::Tuple{Mooncake.Tangent{@NamedTuple{diag::Vector{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}}, Matrix{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}) where {T<:Real} + return (Mooncake.increment!!(x[1], y[1]), Mooncake.increment!!(x[2], y[2])) +end +function Mooncake.increment!!(x::Tuple{Mooncake.Tangent{@NamedTuple{diag::Vector{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}}, Matrix{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}, y::Tuple{Mooncake.Tangent{@NamedTuple{diag::Vector{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}}, Matrix{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}, Vector{T}}) where {T<:Real} + return (Mooncake.increment!!(x[1], y[1]), Mooncake.increment!!(x[2], y[2])) +end + +include("ad_utils.jl") + +make_mooncake_tangent(ΔAelem::T) where {T<:Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) +make_mooncake_tangent(ΔA::Matrix{<:Real}) = ΔA +make_mooncake_tangent(ΔA::Matrix{T}) where {T<:Complex} = map(make_mooncake_tangent, ΔA) +function make_mooncake_tangent(ΔD::Diagonal{T}) where {T<:Real} + return Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) +end +function make_mooncake_tangent(ΔD::Diagonal{T}) where {T<:Complex} + diag_tangent = map(make_mooncake_tangent, diagview(ΔD)) + return Mooncake.build_tangent(typeof(ΔD), diag_tangent) +end + +ETs = (Float64,)# Float32,)# ComplexF64, ComplexF32) + +@timedtestset "QR AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17,)# m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + minmn = min(m, n) + @testset for alg in (LAPACK_HouseholderQR(), + #LAPACK_HouseholderQR(; positive=true), + ) + #=@testset "qr_compact" begin + Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; is_primitive=false, atol=atol, rtol=rtol) + end=# + #=@testset "qr_null" begin + Q, R = qr_compact(A, alg) + ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) + dN = make_mooncake_tangent(ΔN) + Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; output_tangent = dN, is_primitive=false, atol=atol, rtol=rtol) + end=# + @testset "qr_full" begin + Q, R = qr_full(A, alg) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn(rng, T, m, m) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn(rng, T, m, n) + dQ = make_mooncake_tangent(ΔQ) + dR = make_mooncake_tangent(ΔR) + dQR = Mooncake.build_tangent(typeof((ΔQ,ΔR)), dQ, dR) + #Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[2]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + #Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, 1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, minmn+1:m]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + end + #=@testset "qr_compact - rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + Q, R = qr_compact(Ard, alg) + ΔQ = randn(rng, T, m, minmn) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = randn(rng, T, minmn, n) + view(ΔR, (r + 1):minmn, :) .= 0 + dQ = make_mooncake_tangent(ΔQ) + dR = make_mooncake_tangent(ΔR) + dQR = Mooncake.build_tangent(typeof((ΔQ,ΔR)), dQ, dR) + Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + end=# + end + end +end + +#= +@timedtestset "LQ AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + minmn = min(m, n) + @testset for alg in (LAPACK_HouseholderLQ(), + LAPACK_HouseholderLQ(; positive=true), + ) + @testset "lq_compact" begin + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + dL = make_mooncake_tangent(ΔL) + dQ = make_mooncake_tangent(ΔQ) + dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ) + Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; is_primitive=false, atol=atol, rtol=rtol, output_tangent = dLQ) + end + #=@testset "lq_null" begin + L, Q = lq_compact(A, alg) + ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; output_tangent = dNᴴ, is_primitive=false, atol=atol, rtol=rtol) + end=# + @testset "lq_full" begin + L, Q = lq_full(A, alg) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = randn(rng, T, n, n) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + ΔL = randn(rng, T, m, n) + dL = make_mooncake_tangent(ΔL) + dQ = make_mooncake_tangent(ΔQ) + dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ) + Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol) + end + #=@testset "lq_compact - rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + L, Q = lq_compact(Ard, alg) + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + dL = make_mooncake_tangent(ΔL) + dQ = make_mooncake_tangent(ΔQ) + dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ) + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol) + end=# + end + end +end +=# +#= +@timedtestset "EIG AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = randn(rng, T, m, m) + DV = eig_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = randn(rng, complex(T), m, m) + ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = randn(rng, complex(T), m, m) + ΔD2 = Diagonal(randn(rng, complex(T), m)) + + dD = make_mooncake_tangent(ΔD2) + dV = make_mooncake_tangent(ΔV) + dDV = Mooncake.build_tangent(typeof((ΔD2,ΔV)), dD, dV) + # compute the dA corresponding to the above dD, dV + @testset for alg in (LAPACK_Simple(), LAPACK_Expert()) + @testset "eig_full" begin + Ah = (A + A')/2 + Mooncake.TestUtils.test_rule(rng, (A, alg) -> eig_full(exp(A), alg)[1], Ah, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, (A, alg) -> eig_full(exp(A), alg)[2], Ah, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "eig_vals" begin + Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; atol=atol, rtol=rtol, is_primitive=false) + end + @testset "eig_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + dDtrunc = make_mooncake_tangent(ΔDtrunc) + dVtrunc = make_mooncake_tangent(ΔVtrunc) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc,ΔVtrunc,zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode=Mooncake.ReverseMode, output_tangent=dDVtrunc, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + dDtrunc = make_mooncake_tangent(ΔDtrunc) + dVtrunc = make_mooncake_tangent(ΔVtrunc) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc,ΔVtrunc,zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode=Mooncake.ReverseMode, output_tangent=dDVtrunc, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end +end + +function copy_eigh_full(A, alg; kwargs...) + A = (A + A')/2 + eigh_full(A, alg; kwargs...) +end + +function copy_eigh_vals(A, alg; kwargs...) + A = (A + A')/2 + eigh_vals(A, alg; kwargs...) +end + +function copy_eigh_trunc(A, alg; kwargs...) + A = (A + A')/2 + eigh_trunc(A, alg; kwargs...) +end + +@timedtestset "EIGH AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = randn(rng, T, m, m) + A = A + A' + D, V = eigh_full(A) + ΔV = randn(rng, T, m, m) + ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = randn(rng, real(T), m, m) + ΔD2 = Diagonal(randn(rng, real(T), m)) + dD = make_mooncake_tangent(ΔD2) + dV = make_mooncake_tangent(ΔV) + dDV = Mooncake.build_tangent(typeof((ΔD2,ΔV)), dD, dV) + Ddiag = diagview(D) + @testset for alg in (LAPACK_QRIteration(), + LAPACK_DivideAndConquer(), + LAPACK_Bisection(), + LAPACK_MultipleRelativelyRobustRepresentations(), + ) + @testset "eigh_full" begin + Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode=Mooncake.ReverseMode, output_tangent=dDV, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "eigh_vals" begin + Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; is_primitive=false, atol=atol, rtol=rtol) + end + @testset "eigh_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + dDtrunc = make_mooncake_tangent(ΔDtrunc) + dVtrunc = make_mooncake_tangent(ΔVtrunc) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc,ΔVtrunc,zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode=Mooncake.ReverseMode, output_tangent=dDVtrunc, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + dDtrunc = make_mooncake_tangent(ΔDtrunc) + dVtrunc = make_mooncake_tangent(ΔVtrunc) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc,ΔVtrunc,zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode=Mooncake.ReverseMode, output_tangent=dDVtrunc, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end +end + +function dummy_svd_trunc(A, args...; kwargs...) + U, S, Vᴴ, ϵ = svd_trunc(A, args...; kwargs...) + return U, S, Vᴴ +end + +@timedtestset "SVD AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + minmn = min(m, n) + @testset for alg in (LAPACK_QRIteration(), + #LAPACK_DivideAndConquer(), + ) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + ΔUref = copy(ΔU) + ΔSref = copy(ΔS2) + ΔVᴴref = copy(ΔVᴴ) + @testset "svd_compact" begin + dS = make_mooncake_tangent(ΔS2) + dU = make_mooncake_tangent(ΔU) + dVᴴ = make_mooncake_tangent(ΔVᴴ) + dUSVᴴ = Mooncake.build_tangent(typeof((ΔU,ΔS2,ΔVᴴ)), dU, dS, dVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; mode=Mooncake.ReverseMode, output_tangent=dUSVᴴ, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "svd_full" begin + ΔUfull = zeros(T, m, m) + ΔSfull = zeros(real(T), m, n) + ΔVᴴfull = zeros(T, n, n) + U, S, Vᴴ = svd_full(A) + view(ΔUfull, :, 1:minmn) .= ΔUref + view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴref + diagview(ΔSfull)[1:minmn] .= diagview(ΔSref) + dS = make_mooncake_tangent(ΔSfull) + dU = make_mooncake_tangent(ΔUfull) + dVᴴ = make_mooncake_tangent(ΔVᴴfull) + dUSVᴴ = Mooncake.build_tangent(typeof((ΔUfull,ΔSfull,ΔVᴴfull)), dU, dS, dVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_full, A, alg; mode=Mooncake.ReverseMode, output_tangent=dUSVᴴ, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "svd_vals" begin + Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; is_primitive=false, atol=atol, rtol=rtol) + end + #= + @testset "svd_trunc" begin + @testset for r in 1:4:minmn + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] + ΔVᴴtrunc = ΔVᴴ[ind, :] + dStrunc = make_mooncake_tangent(ΔStrunc) + dUtrunc = make_mooncake_tangent(ΔUtrunc) + dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) + dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU,ΔS2,ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, dummy_svd_trunc, copy(A), truncalg; mode=Mooncake.ReverseMode, output_tangent=dUSVᴴerr, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] + ΔVᴴtrunc = ΔVᴴ[ind, :] + dStrunc = make_mooncake_tangent(ΔStrunc) + dUtrunc = make_mooncake_tangent(ΔUtrunc) + dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) + dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU,ΔS2,ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, dummy_svd_trunc, copy(A), truncalg; mode=Mooncake.ReverseMode, output_tangent=dUSVᴴerr, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end=# + end + end +end + +@timedtestset "Polar AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + @testset for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) + m >= n && + Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; is_primitive=false, atol=atol, rtol=rtol) + + m <= n && + Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; is_primitive=false, atol=atol, rtol=rtol) + end + end +end + +@timedtestset "Orth and null with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + Mooncake.TestUtils.test_rule(rng, left_orth, A; atol=atol, rtol=rtol, is_primitive=false) + Mooncake.TestUtils.test_rule(rng, right_orth, A; atol=atol, rtol=rtol, is_primitive=false) + + Mooncake.TestUtils.test_rule(rng, (X->left_orth(X; kind=:qr)), A; atol=atol, rtol=rtol, is_primitive=false) + if m >= n + Mooncake.TestUtils.test_rule(rng, (X->left_orth(X; kind=:polar)), A; atol=atol, rtol=rtol, is_primitive=false) + end + + ΔN = left_orth(A; kind=:qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + dN = make_mooncake_tangent(ΔN) + Mooncake.TestUtils.test_rule(rng, (X->left_null(X; kind=:qr)), A; atol=atol, rtol=rtol, is_primitive=false, output_tangent = dN) + + Mooncake.TestUtils.test_rule(rng, (X->right_orth(X; kind=:lq)), A; atol=atol, rtol=rtol, is_primitive=false) + + if m <= n + Mooncake.TestUtils.test_rule(rng, (X->right_orth(X; kind=:polar)), A; atol=atol, rtol=rtol, is_primitive=false) + end + + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lq)[2] + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, (X->right_null(X; kind=:lq)), A; atol=atol, rtol=rtol, is_primitive=false, output_tangent = dNᴴ) + end +end +=# diff --git a/test/runtests.jl b/test/runtests.jl index ec255538..2e1f7b0b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using SafeTestsets # specific ones is_buildkite = get(ENV, "BUILDKITE", "false") == "true" if !is_buildkite - @safetestset "Algorithms" begin + #=@safetestset "Algorithms" begin include("algorithms.jl") end @safetestset "Projections" begin @@ -38,6 +38,14 @@ if !is_buildkite @safetestset "Image and Null Space" begin include("orthnull.jl") end + =# + @safetestset "Mooncake" begin + include("mooncake.jl") + end + #=@safetestset "Enzyme" begin + include("enzyme.jl") + end=# + #= @safetestset "ChainRules" begin include("chainrules.jl") end @@ -52,7 +60,7 @@ if !is_buildkite using JET JET.test_package(MatrixAlgebraKit; target_defined_modules = true) end - end + end=# end using CUDA @@ -81,6 +89,12 @@ if CUDA.functional() @safetestset "CUDA Image and Null Space" begin include("cuda/orthnull.jl") end + #=@safetestset "CUDA Mooncake" begin + include("cuda/mooncake.jl") + end + @safetestset "CUDA Enzyme" begin + include("cuda/enzyme.jl") + end=# end using AMDGPU @@ -106,6 +120,9 @@ if AMDGPU.functional() @safetestset "AMDGPU Image and Null Space" begin include("amd/orthnull.jl") end + #=@safetestset "AMDGPU Enzyme" begin + include("amd/enzyme.jl") + end=# end using GenericLinearAlgebra