From 0811b2e18c2d2c77bf64937245d47692545a78c2 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 9 Sep 2025 19:56:22 +0200 Subject: [PATCH 1/6] WIP Enzyme and Mooncake rules --- Project.toml | 11 +- .../MatrixAlgebraKitEnzymeExt.jl | 713 ++++++++++++++++++ .../MatrixAlgebraKitMooncakeExt.jl | 277 +++++++ src/MatrixAlgebraKit.jl | 7 + src/common/view.jl | 2 +- src/implementations/eigh.jl | 2 +- src/implementations/svd.jl | 54 +- src/pullbacks/eig.jl | 3 +- src/pullbacks/polar.jl | 4 +- src/pullbacks/svd.jl | 19 +- src/pullfwds/eig.jl | 12 + src/pullfwds/eigh.jl | 14 + src/pullfwds/lq.jl | 58 ++ src/pullfwds/polar.jl | 2 + src/pullfwds/qr.jl | 69 ++ src/pullfwds/svd.jl | 21 + test/ad_utils.jl | 26 + test/chainrules.jl | 33 +- test/cuda/enzyme.jl | 326 ++++++++ test/cuda/mooncake.jl | 210 ++++++ test/enzyme.jl | 350 +++++++++ test/mooncake.jl | 400 ++++++++++ test/runtests.jl | 15 + 23 files changed, 2540 insertions(+), 88 deletions(-) create mode 100644 ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl create mode 100644 ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl create mode 100644 src/pullfwds/eig.jl create mode 100644 src/pullfwds/eigh.jl create mode 100644 src/pullfwds/lq.jl create mode 100644 src/pullfwds/polar.jl create mode 100644 src/pullfwds/qr.jl create mode 100644 src/pullfwds/svd.jl create mode 100644 test/ad_utils.jl create mode 100644 test/cuda/enzyme.jl create mode 100644 test/cuda/mooncake.jl create mode 100644 test/enzyme.jl create mode 100644 test/mooncake.jl diff --git a/Project.toml b/Project.toml index 934c0ceb..ec057141 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,8 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a" GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] 
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore" @@ -19,6 +21,8 @@ MatrixAlgebraKitAMDGPUExt = "AMDGPU" MatrixAlgebraKitCUDAExt = "CUDA" MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra" MatrixAlgebraKitGenericSchurExt = "GenericSchur" +MatrixAlgebraKitEnzymeExt = "Enzyme" +MatrixAlgebraKitMooncakeExt = "Mooncake" [compat] AMDGPU = "2" @@ -28,8 +32,11 @@ ChainRulesTestUtils = "1" CUDA = "5" GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" +Enzyme = "0.13.77" +EnzymeTestUtils = "0.2.3" JET = "0.9, 0.10" LinearAlgebra = "1" +Mooncake = "0.4.167" SafeTestsets = "0.1" StableRNGs = "1" Test = "1" @@ -40,7 +47,9 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -48,4 +57,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Enzyme", "EnzymeTestUtils", "Mooncake"] diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl new file mode 100644 index 00000000..db577784 --- /dev/null +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -0,0 +1,713 @@ +module MatrixAlgebraKitEnzymeExt + +using MatrixAlgebraKit +using MatrixAlgebraKit: diagview, inv_safe 
+using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pullfwd!, lq_pullfwd! +using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pullfwd!, lq_null_pullfwd! +using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pullfwd!, eigh_pullfwd! +using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pullfwd!, right_polar_pullfwd! +using Enzyme +using Enzyme.EnzymeCore +using Enzyme.EnzymeCore: EnzymeRules +using LinearAlgebra + +@inline EnzymeRules.inactive_type(v::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = true + + +# two-argument factorizations like LQ, QR, EIG +for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pullfwd!), + (qr_compact!, qr_pullback!, qr_pullfwd!), + (lq_full!, lq_pullback!, lq_pullfwd!), + (lq_compact!, lq_pullback!, lq_pullfwd!), + (eig_full!, eig_pullback!, eig_pullfwd!), + (left_polar!, left_polar_pullback!, left_polar_pullfwd!), + (right_polar!, right_polar_pullback!, right_polar_pullfwd!), + ) + @eval begin + function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + cache_arg = nothing + # form cache if needed + cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(arg) <: Const)) ? copy(A.val) : nothing + func.val(A.val, arg.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? arg.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg)) + end + function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs...) 
where {RT} + cache_A, cache_arg = cache + argval = arg.val + Aval = !isnothing(cache_A) ? cache_A : A.val + ∂arg = isa(arg, Const) ? nothing : arg.dval + if !isa(A, Const) && !isa(arg, Const) + A.dval .= zero(eltype(Aval)) + $pb(A.dval, A.val, argval, ∂arg; kwargs...) + end + !isa(arg, Const) && make_zero!(arg.dval) + return (nothing, nothing, nothing) + end + function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + ret = func.val(A.val, arg.val, alg.val; kwargs...) + arg1, arg2 = ret + m, n = size(A.val) + + if isa(arg, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const) + darg1, darg2 = arg.dval + dA = A.dval + darg1, darg2 = $pf(dA, A.val, ret, arg.dval) + dA .= zero(eltype(A.val)) + shadow = (darg1, darg2) + elseif isa(A, Const) && !!isa(arg, Union{Duplicated, DuplicatedNoNeed}) + make_zero!(arg.dval) + shadow = arg.dval + end + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(ret, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return ret + else + return nothing + end + end + end +end + +for (f, pb, pf) in ((qr_null!, qr_null_pullback!, qr_null_pullfwd!), + (lq_null!, lq_null_pullback!, lq_null_pullfwd!), + ) + @eval begin + function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:AbstractMatrix}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + cache_arg = nothing + # form cache if needed + cache_A = nothing #copy(A.val) + func.val(copy(A.val), arg.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? arg.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
arg.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg)) + end + + function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:AbstractMatrix}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(arg.val), + rank_atol::Real=tol, + gauge_atol::Real=tol, + kwargs...) where {RT} + cache_A, cache_arg = cache + Aval = isnothing(cache_A) ? A.val : cache_A + if !isa(A, Const) && !isa(arg, Const) + A.dval .= zero(eltype(A.val)) + $pb(A.dval, A.val, arg.val, arg.dval; kwargs...) + end + return (nothing, nothing, nothing) + end + function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + arg::Annotation{<:AbstractMatrix}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + ret = func.val(A.val, arg.val, alg.val; kwargs...) + + if isa(arg, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const) + darg = arg.dval + dA = A.dval + $pf(dA, A.val, arg.val, darg) + shadow = darg + elseif isa(A, Const) && !!isa(arg, Union{Duplicated, DuplicatedNoNeed}) + make_zero!(arg.dval) + shadow = arg.dval + end + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(ret, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return ret + else + return nothing + end + end + end +end + + +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof(svd_compact!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? 
func.val(A.val, USVᴴ.val; kwargs...) : nothing + shadow = if EnzymeRules.needs_shadow(config) + U, S, Vᴴ = ret + V = adjoint(Vᴴ) + ∂S = Diagonal(diag(real.(U' * A.dval * V))) + m, n = size(A.val) + F = one(eltype(S)) ./ ((diagview(S).^2)' .- (diagview(S) .^ 2)) + diagview(F) .= zero(eltype(F)) + invSdiag = zeros(eltype(S), length(S.diag)) + for i in 1:length(S.diag) + @inbounds invSdiag[i] = inv(diagview(S)[i]) + end + invS = Diagonal(invSdiag) + ∂U = U * (F .* (U' * A.dval * V * S + S * Vᴴ * A.dval' * U)) + (diagm(ones(eltype(U), m)) - U*U') * A.dval * V * invS + #∂Vᴴ = (FSdS' * Vᴴ) + (invS * U' * A.dval * (diagm(ones(eltype(U), size(V, 2))) - Vᴴ*V)) + ∂V = V * (F .* (S * U' * A.dval * V + Vᴴ * A.dval' * U * S)) + (diagm(ones(eltype(V), n)) - V*Vᴴ) * A.dval' * U * invS + ∂Vᴴ = similar(Vᴴ) + adjoint!(∂Vᴴ, ∂V) + (∂U, ∂S, ∂Vᴴ) + else + nothing + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(ret, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return ret + else + return nothing + end +end + +# TODO +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof(svd_full!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) 
: nothing + shadow = if EnzymeRules.needs_shadow(config) + fatU, fatS, fatVᴴ = ret + ∂Ufat = zeros(eltype(fatU), size(fatU)) + ∂Sfat = zeros(eltype(fatS), size(fatS)) + ∂Vᴴfat = zeros(eltype(fatVᴴ), size(fatVᴴ)) + m, n = size(A.val) + minmn = min(m, n) + #U = view(fatU, :, 1:minmn) + #S = Diagonal(diagview(fatS)) + #Vᴴ = view(fatVᴴ, 1:minmn, :) + U = fatU + S = fatS + Vᴴ = fatVᴴ + V = adjoint(Vᴴ) + ∂S = Diagonal(diag(real.(U' * A.dval * V))) + diagview(∂Sfat) .= diagview(∂S) + m, n = size(A.val) + F = one(eltype(S)) ./ ((diagview(S).^2)' .- (diagview(S) .^ 2)) + diagview(F) .= zero(eltype(F)) + invSdiag = zeros(eltype(S), size(S)) + for ix in diagind(S) + @inbounds invSdiag[ix] = inv(S[ix]) + end + invS = invSdiag + #FSdS = F .* (∂S * S .+ S * ∂S) + ∂U = U * (F .* (U' * A.dval * V * S + S * Vᴴ * A.dval' * U)) + (diagm(ones(eltype(U), m)) - U*U') * A.dval * V * invS + #view(∂Ufat, :, 1:minmn) .= view(∂U, :, :) + ∂Ufat .= ∂U + + + #∂Vᴴ = (FSdS' * Vᴴ) + (invS * U' * A.dval * (diagm(ones(eltype(U), size(V, 2))) - Vᴴ*V)) + ∂V = V * (F .* (S * U' * A.dval * V + Vᴴ * A.dval' * U * S)) + (diagm(ones(eltype(V), n)) - V*Vᴴ) * A.dval' * U * invS + ∂Vᴴ = similar(Vᴴ) + adjoint!(∂Vᴴ, ∂V) + #view(∂Vᴴfat, 1:minmn, :) .= view(∂Vᴴ, :, :) + ∂Vᴴfat .= ∂Vᴴ + #=view(∂Ufat, :, minmn+1:m) .= zero(eltype(fatU)) + view(∂Vᴴfat, minmn+1:n, :) .= zero(eltype(fatVᴴ)) + view(∂Sfat, minmn+1:m, :) .= zero(eltype(fatVᴴ)) + view(∂Sfat, :, minmn+1:n) .= zero(eltype(fatVᴴ))=# + (∂Ufat, ∂Sfat, ∂Vᴴfat) + else + nothing + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(ret, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return ret + else + return nothing + end +end +for f in (:svd_compact!, :svd_full!) 
+ @eval begin + function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + # form cache if needed + cache_USVᴴ = (EnzymeRules.overwritten(config)[3] && !(typeof(USVᴴ) <: Const)) ? copy(USVᴴ.val) : nothing + cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(A) <: Const)) ? copy(A.val) : nothing + func.val(A.val, USVᴴ.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? USVᴴ.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? USVᴴ.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ)) + end + function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs...) where {RT} + cache_A, cache_USVᴴ = cache + USVᴴval = !isnothing(cache_USVᴴ) ? cache_USVᴴ : USVᴴ.val + ∂USVᴴ = isa(USVᴴ, Const) ? nothing : USVᴴ.dval + if !isa(A, Const) && !isa(USVᴴ, Const) + A.dval .= zero(eltype(A.dval)) + MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴval, ∂USVᴴ; kwargs...) 
+ end + if !isa(USVᴴ, Const) + make_zero!(USVᴴ.dval) + end + return (nothing, nothing, nothing) + end + end +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + ϵ::Annotation{Vector{T}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT, T<:Real} + # form cache if needed + cache_A = copy(A.val) + svd_compact!(A.val, USVᴴ.val, alg.val.alg) + cache_USVᴴ = copy.(USVᴴ.val) + USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc) + ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind) + primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing + shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴ, Const) + dU, dS, dVᴴ = USVᴴ.dval + dStrunc = Diagonal(diagview(dS)[ind]) + dUtrunc = dU[:, ind] + dVᴴtrunc = dVᴴ[ind, :] + (dUtrunc, dStrunc, dVᴴtrunc) + else + (nothing, nothing, nothing) + end + shadow = EnzymeRules.needs_shadow(config) ? (shadow_USVᴴ..., ϵ.dval) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, shadow_USVᴴ, ind)) +end +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc!)}, + dret::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, + ϵ::Annotation{Vector{T}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs...) where {RT, T<:Real} + cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache + U, S, Vᴴ = cache_USVᴴ + dU, dS, dVᴴ = shadow_USVᴴ + if !isa(A, Const) && !isa(USVᴴ, Const) + A.dval .= zero(eltype(A.val)) + A.dval .= MatrixAlgebraKit.svd_pullback!(A.dval, A.val, (U, S, Vᴴ), shadow_USVᴴ, ind; kwargs...) 
+ end + if !isa(USVᴴ, Const) + make_zero!(USVᴴ.dval) + end + if !isa(ϵ, Const) + ϵ.dval .= zero(T) + end + return (nothing, nothing, nothing, nothing) +end + +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof(eigh_vals!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + Dmat, V = eigh_full(A.val; kwargs...) + if isa(D, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const) + ∂K = inv(V) * A.dval * V + ∂Kdiag = diag(∂K) + D.dval .= real.(copy(∂Kdiag)) + A.dval .= zero(eltype(A.val)) + shadow = D.dval + elseif isa(A, Const) && !!isa(D, Union{Duplicated, DuplicatedNoNeed}) + make_zero!(D.dval) + shadow = D.dval + end + eigh_vals!(A.val, zeros(real(eltype(A.val)), size(A.val, 1))) + D.val .= diagview(Dmat) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(Dmat.diag, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return Dmat.diag + else + return nothing + end +end + +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof(eigh_full!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + Dmat, V = func.val(A.val, DV.val; kwargs...) 
+ if isa(A, Const) || all(iszero, A.dval) + make_zero!(DV.dval[1]) + make_zero!(DV.dval[2]) + make_zero!(A.dval) + shadow = (DV.dval[1], DV.dval[2]) + else + ∂K = inv(V) * A.dval * V + ∂Kdiag = diagview(∂K) + ∂Ddiag = diagview(DV.dval[1]) + ∂Ddiag .= real.(∂Kdiag) + D = diagview(Dmat) + dDD = transpose(D) .- D + ∂K ./= dDD + ∂Kdiag .= zero(eltype(V)) + mul!(DV.dval[2], V, ∂K, 1, 0) + shadow = DV.dval[2] + A.dval .= zero(eltype(A.val)) + shadow = (Diagonal(∂Ddiag), DV.dval[2]) + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated((Dmat, V), shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return (Dmat, V) + else + return nothing + end +end + +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + func::Const{typeof(eig_vals!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + Dval, V = eig_full(A.val, alg.val; kwargs...) 
+ if isa(D, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const) + ∂K = inv(V) * A.dval * V + ∂Kdiag = diag(∂K) + D.dval .= copy(∂Kdiag) + A.dval .= zero(eltype(A.val)) + shadow = D.dval + elseif isa(A, Const) && !!isa(D, Union{Duplicated, DuplicatedNoNeed}) + make_zero!(D.dval) + shadow = D.dval + end + eig_vals!(A.val, zeros(complex(eltype(A.val)), size(A.val, 1))) + D.val .= diagview(Dval) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(Dmat.diag, shadow) + elseif EnzymeRules.needs_shadow(config) + return shadow + elseif EnzymeRules.needs_primal(config) + return Dmat.diag + else + return nothing + end +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_trunc!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + # form cache if needed + cache_A = copy(A.val) + eigh_full!(A.val, DV.val, alg.val.alg) + cache_DV = copy.(DV.val) + DV′, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc) + ϵ = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind) + primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ) : nothing + shadow_DV = if !isa(A, Const) && !isa(DV, Const) + dD, dV = DV.dval + dDtrunc = Diagonal(diagview(dD)[ind]) + dVtrunc = dV[:, ind] + (dDtrunc, dVtrunc) + else + (nothing, nothing) + end + shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., zero(ϵ)) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind)) +end +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_trunc!)}, + dret, + cache, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; + kwargs...) 
+ cache_A, cache_DV, cache_dDVtrunc, ind = cache + D, V = cache_DV + dD, dV = cache_dDVtrunc + if !isa(A, Const) && !isa(DV, Const) + A.dval .= zero(eltype(A.val)) + A.dval .= MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, (D, V), (dD, dV), ind; kwargs...) + end + if !isa(DV, Const) + make_zero!(DV.dval) + end + return (nothing, nothing, nothing, nothing) +end +#= +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eig_trunc!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + # form cache if needed + cache_A = copy(A.val) + eig_full!(A.val, DV.val, alg.val.alg) + cache_DV = copy.(DV.val) + DV′, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc) + ϵ = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind) + primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ) : nothing + shadow_DV = if !isa(A, Const) && !isa(DV, Const) + dD, dV = DV.dval + dDtrunc = Diagonal(diagview(dD)[ind]) + dVtrunc = dV[:, ind] + (dDtrunc, dVtrunc) + else + (nothing, nothing) + end + shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., zero(ϵ)) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind)) +end +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eig_trunc!)}, + dret, + cache, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; + kwargs...) + cache_A, cache_DV, cache_dDVtrunc, ind = cache + D, V = cache_DV + dD, dV = cache_dDVtrunc + if !isa(A, Const) && !isa(DV, Const) + A.dval .= zero(eltype(A.val)) + A.dval .= MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, (D, V), (dD, dV), ind; kwargs...) 
+ end + if !isa(DV, Const) + make_zero!(DV.dval) + end + return (nothing, nothing, nothing, nothing) +end +=# +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_full!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + # form cache if needed + cache_DV = nothing + cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + func.val(A.val, DV.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? DV.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? DV.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV)) +end + +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_full!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + + cache_A, cache_DV = cache + DVval = !isnothing(cache_DV) ? cache_DV : DV.val + Aval = !isnothing(cache_A) ? cache_A : A.val + ∂DV = isa(DV, Const) ? nothing : DV.dval + if !isa(A, Const) && !isa(DV, Const) + Dmat, V = DVval + ∂Dmat, ∂V = ∂DV + A.dval .= zero(eltype(Aval)) + MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, DVval, ∂DV; kwargs...) 
+ A.dval .*= 2 + diagview(A.dval) ./= 2 + for i in 1:size(A.dval, 1), j in 1:size(A.dval, 2) + if i > j + A.dval[i, j] = zero(eltype(A.dval)) + end + end + end + if !isa(DV, Const) + make_zero!(DV.dval) + end + return (nothing, nothing, nothing) +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eig_vals!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + cache_D = nothing + cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + func.val(A.val, D.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? D.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing + # form cache if needed + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D)) +end +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eig_vals!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + + cache_A, cache_D = cache + Dval = !isnothing(cache_D) ? cache_D : D.val + Aval = !isnothing(cache_A) ? cache_A : A.val + ∂D = isa(D, Const) ? 
nothing : D.dval + if !isa(A, Const) && !isa(D, Const) + _, V = eig_full(Aval, alg.val) + A.dval .= zero(eltype(Aval)) + PΔV = V' \ Diagonal(D.dval) + if eltype(A.dval) <: Real + ΔAc = PΔV * V' + A.dval .+= real.(ΔAc) + else + mul!(A.dval, PΔV, V', 1, 0) + end + end + if !isa(D, Const) + make_zero!(D.dval) + end + return (nothing, nothing, nothing) +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_vals!)}, + ::Type{RT}, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + cache_D = nothing + cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + func.val(A.val, D.val, alg.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? D.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing + # form cache if needed + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D)) +end +function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(eigh_vals!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractMatrix}, + D::Annotation{<:AbstractVector}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + kwargs..., + ) where {RT} + + cache_A, cache_D = cache + Dval = !isnothing(cache_D) ? cache_D : D.val + Aval = !isnothing(cache_A) ? cache_A : A.val + ∂D = isa(D, Const) ? 
nothing : D.dval + if !isa(A, Const) && !isa(D, Const) + _, V = eigh_full(Aval, alg.val) + A.dval .= zero(eltype(Aval)) + mul!(A.dval, V * Diagonal(real(∂D)), V', 1, 0) + A.dval .*= 2 + diagview(A.dval) ./= 2 + for i in 1:size(A.dval, 1), j in 1:size(A.dval, 2) + if i > j + A.dval[i, j] = zero(eltype(A.dval)) + end + end + end + if !isa(D, Const) + make_zero!(D.dval) + end + return (nothing, nothing, nothing) +end + +end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl new file mode 100644 index 00000000..8af2a95f --- /dev/null +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -0,0 +1,277 @@ +module MatrixAlgebraKitMooncakeExt + +using Mooncake +using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive +using MatrixAlgebraKit +using MatrixAlgebraKit: inv_safe, diagview +using MatrixAlgebraKit: svd_pullfwd! +using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pullfwd!, lq_pullfwd! +using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pullfwd!, lq_null_pullfwd! +using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pullfwd!, eigh_pullfwd! +using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pullfwd!, right_polar_pullfwd! 
+using LinearAlgebra + +# two-argument factorizations like LQ, QR, EIG +for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pullfwd!, :dqr_adjoint), + (qr_compact!, qr_pullback!, qr_pullfwd!, :dqr_adjoint), + (lq_full!, lq_pullback!, lq_pullfwd!, :dlq_adjoint), + (lq_compact!, lq_pullback!, lq_pullfwd!, :dlq_adjoint), + (eig_full!, eig_pullback!, eig_pullfwd!, :deig_adjoint), + (eigh_full!, eigh_pullback!, eigh_pullfwd!, :deigh_adjoint), + (left_polar!, left_polar_pullback!, left_polar_pullfwd!, :dleft_polar_adjoint), + (right_polar!, right_polar_pullback!, right_polar_pullfwd!, :dright_polar_adjoint), + ) + + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) + A, dA = arrayify(A_dA) + dA .= zero(eltype(A)) + args = Mooncake.primal(args_dargs) + dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(args[1], dargs[1]) + arg2, darg2 = arrayify(args[2], dargs[2]) + function $adj(::Mooncake.NoRData) + dA = $pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...) + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...) + darg1 .= zero(eltype(arg1)) + darg2 .= zero(eltype(arg2)) + return Mooncake.CoDual(args, dargs), $adj + end + @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) + A, dA = arrayify(A_dA) + args = Mooncake.primal(args_dargs) + args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...) 
+ dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(args[1], dargs[1]) + arg2, darg2 = arrayify(args[2], dargs[2]) + darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2)) + dA .= zero(eltype(A)) + return Mooncake.Dual(args, dargs) + end + end +end + +for (f, pb, pf, adj) in ((qr_null!, qr_null_pullback!, qr_null_pullfwd!, :dqr_null_adjoint), + (lq_null!, lq_null_pullback!, lq_null_pullfwd!, :dlq_null_adjoint), + ) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, arg_darg::CoDual{<:AbstractMatrix}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) + A, dA = arrayify(A_dA) + Ac = MatrixAlgebraKit.copy_input(lq_full, A) + arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg)) + arg = $f(Ac, arg, Mooncake.primal(alg_dalg)) + function $adj(::Mooncake.NoRData) + dA .= zero(eltype(A)) + $pb(dA, A, arg, darg; kwargs...) + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + return arg_darg, $adj + end + #forward mode not implemented yet + end +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...) + # compute primal + D_ = Mooncake.primal(D_dD) + dD_ = Mooncake.tangent(D_dD) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + D, dD = arrayify(D_, dD_) + nD, V = eig_full(A, alg_dalg.primal; kwargs...) 
+ + # update tangent + tmp = V \ dA + dD .= diagview(tmp * V) + dA .= zero(eltype(dA)) + return Mooncake.Dual(nD.diag, dD_) +end + +function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...) + # compute primal + D_ = Mooncake.primal(D_dD) + dD_ = Mooncake.tangent(D_dD) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + D, dD = arrayify(D_, dD_) + dA .= zero(eltype(dA)) + # update primal + DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...) + V = DV[2] + dD .= zero(eltype(D)) + function deig_vals_adjoint(::Mooncake.NoRData) + PΔV = V' \ Diagonal(dD) + if eltype(dA) <: Real + ΔAc = PΔV * V' + dA .+= real.(ΔAc) + else + mul!(dA, PΔV, V', 1, 0) + end + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + return Mooncake.CoDual(DV[1].diag, dD_), deig_vals_adjoint +end +#= +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_full!), AbstractMatrix, Tuple{<:Diagonal, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_full!)}, A_dA::CoDual{<:AbstractMatrix}, DV_dDV::CoDual{<:Tuple{<:Diagonal, <:AbstractMatrix}}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) + A, dA = arrayify(A_dA) + dA .= zero(eltype(A)) + DV = Mooncake.primal(DV_dDV) + dDV = Mooncake.tangent(DV_dDV) + D, dD = arrayify(DV[1], dDV[1]) + V, dV = arrayify(DV[2], dDV[2]) + function deigh_adjoint(::Mooncake.NoRData) + dA = MatrixAlgebraKit.eigh_pullback!(dA, A, (D, V), (dD, dV); kwargs...) + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + DV = eigh_full!(A, DV, Mooncake.primal(alg_dalg); kwargs...) 
+ return Mooncake.CoDual(DV, dDV), deigh_adjoint +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eigh_full!), AbstractMatrix, Tuple{<:Diagonal, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.frule!!(::Dual{typeof(MatrixAlgebraKit.eigh_full!)}, A_dA::Dual, DV_dDV::Dual, alg_dalg::Dual; kwargs...) + A, dA = arrayify(A_dA) + DV = Mooncake.primal(DV_dDV) + dDV = Mooncake.tangent(DV_dDV) + D, dD = arrayify(DV[1], dDV[1]) + V, dV = arrayify(DV[2], dDV[2]) + (D, V) = eigh_full!(A, DV, Mooncake.primal(alg_dalg); kwargs...) + (dD, dV) = eigh_pullfwd!(dA, A, (D, V), (dD, dV); kwargs...) + return Mooncake.Dual(DV, dDV) +end +=# + +@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...) + # compute primal + D_ = Mooncake.primal(D_dD) + dD_ = Mooncake.tangent(D_dD) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + D, dD = arrayify(D_, dD_) + nD, V = eigh_full(A, alg_dalg.primal; kwargs...) + # update tangent + tmp = inv(V) * dA * V + dD .= real.(diagview(tmp)) + D .= nD.diag + dA .= zero(eltype(dA)) + return D_dD +end + +function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...) + # compute primal + D_ = Mooncake.primal(D_dD) + dD_ = Mooncake.tangent(D_dD) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + D, dD = arrayify(D_, dD_) + DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...) 
+ function deigh_vals_adjoint(::Mooncake.NoRData) + mul!(dA, DV[2] * Diagonal(real(dD)), DV[2]', 1, 0) + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + return Mooncake.CoDual(DV[1].diag, dD_), deigh_vals_adjoint +end + + +for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal)) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...) + A, dA = arrayify(A_dA) + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) + S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) + Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...) + function dsvd_adjoint(::Mooncake.NoRData) + dA .= zero(eltype(A)) + minmn = min(size(A)...) + if size(U, 2) == size(Vᴴ, 1) == minmn # compact + dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + else # full: svd_pullback! needs factors and cotangents truncated to the same leading minmn columns/rows + vU = view(U, :, 1:minmn) + vS = Diagonal(diagview(S)[1:minmn]) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(dU, :, 1:minmn) + vdS = Diagonal(diagview(dS)[1:minmn]) + vdVᴴ = view(dVᴴ, 1:minmn, :) + dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) + end + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + return Mooncake.CoDual(USVᴴ, dUSVᴴ), dsvd_adjoint + end + @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{<:typeof($f)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual; kwargs...)
+ # compute primal + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + $f(A, USVᴴ, alg_dalg.primal; kwargs...) + + # update tangents + U_, S_, Vᴴ_ = USVᴴ + dU_, dS_, dVᴴ_ = dUSVᴴ + U, dU = arrayify(U_, dU_) + S, dS = arrayify(S_, dS_) + Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_) + (dU, dS, dVᴴ) = svd_pullfwd!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...) + return USVᴴ_dUSVᴴ + end + end +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual; kwargs...) + # compute primal + S_ = Mooncake.primal(S_dS) + dS_ = Mooncake.tangent(S_dS) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...) + + # update tangent + S, dS = arrayify(S_, dS_) + copyto!(dS, diag(real.(Vᴴ * dA' * U))) + copyto!(S, diagview(nS)) + dA .= zero(eltype(dA)) + return Mooncake.Dual(nS.diag, dS) +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...) + # compute primal + S_ = Mooncake.primal(S_dS) + dS_ = Mooncake.tangent(S_dS) + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + S, dS = arrayify(S_, dS_) + U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...) 
+ S .= diagview(nS) + dS .= zero(eltype(S)) + function dsvd_vals_adjoint(::Mooncake.NoRData) + dA .= U * Diagonal(dS) * Vᴴ + return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() + end + return S_dS, dsvd_vals_adjoint +end + +end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 4a846f85..3cda864c 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -112,4 +112,11 @@ include("pullbacks/eigh.jl") include("pullbacks/svd.jl") include("pullbacks/polar.jl") +include("pullfwds/qr.jl") +include("pullfwds/lq.jl") +include("pullfwds/eig.jl") +include("pullfwds/eigh.jl") +include("pullfwds/polar.jl") +include("pullfwds/svd.jl") + end diff --git a/src/common/view.jl b/src/common/view.jl index c8ae1aa5..0bc7b9ef 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -1,5 +1,5 @@ # diagind: provided by LinearAlgebra.jl -diagview(D::Diagonal) = D.diag +diagview(D::Diagonal) = D.diag diagview(D::AbstractMatrix) = view(D, diagind(D)) # triangularind diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 0cfa6db0..bb766787 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -19,7 +19,7 @@ function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real = end function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm) - check_hermitian(A, alg) + #check_hermitian(A, alg) D, V = DV m = size(A, 1) @assert D isa Diagonal && V isa AbstractMatrix diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index fed36cd1..50f352b0 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -89,7 +89,7 @@ end function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlgorithm) return similar(A, real(eltype(A)), (min(size(A)...),)) end -function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm) +function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, 
alg::TruncatedAlgorithm) return initialize_output(svd_compact!, A, alg.alg) end @@ -333,46 +333,25 @@ function _gpu_gesvdj!( ) throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ))) end -function _gpu_gesvd_maybe_transpose!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix) - m, n = size(A) - m ≥ n && return _gpu_gesvd!(A, S, U, Vᴴ) - # both CUSOLVER and ROCSOLVER require m ≥ n for gesvd (QR_Iteration) - # if this condition is not met, do the SVD via adjoint - minmn = min(m, n) - Aᴴ = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A') - Uᴴ = similar(U') - V = similar(Vᴴ') - if size(U) == (m, m) - _gpu_gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ) - else - _gpu_gesvd!(Aᴴ, S, V, Uᴴ) - end - length(U) > 0 && adjoint!(U, Uᴴ) - length(Vᴴ) > 0 && adjoint!(Vᴴ, V) - return U, S, Vᴴ -end - # GPU SVD implementation -function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) +function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_full!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ fill!(S, zero(eltype(S))) m, n = size(A) minmn = min(m, n) - if minmn == 0 - one!(U) - zero!(S) - one!(Vᴴ) - return USVᴴ - end if alg isa GPU_QRIteration isempty(alg.kwargs) || - @warn "GPU_QRIteration does not accept any keyword arguments" - _gpu_gesvd_maybe_transpose!(A, view(S, 1:minmn, 1), U, Vᴴ) + throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) + _gpu_gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa GPU_SVDPolar _gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) elseif alg isa GPU_Jacobi _gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) 
+ # elseif alg isa LAPACK_Bisection + # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) + # elseif alg isa LAPACK_Jacobi + # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) else throw(ArgumentError("Unsupported SVD algorithm")) end @@ -389,21 +368,16 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) # TODO: make this controllable using a `gaugefix` keyword argument gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...) - # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong - USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - Strunc = diagview(USVᴴtrunc[2]) - # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum - ϵ = sqrt(norm(A)^2 - norm(Strunc)^2) # is there a more accurate way to do this? - return USVᴴtrunc..., ϵ + return first(truncate(svd_trunc!, USVᴴ, alg.trunc)) end -function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) +function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ if alg isa GPU_QRIteration isempty(alg.kwargs) || - @warn "GPU_QRIteration does not accept any keyword arguments" - _gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ) + throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) + _gpu_gesvd!(A, S.diag, U, Vᴴ) elseif alg isa GPU_SVDPolar _gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...) 
elseif alg isa GPU_Jacobi @@ -423,8 +397,8 @@ function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) if alg isa GPU_QRIteration isempty(alg.kwargs) || - @warn "GPU_QRIteration does not accept any keyword arguments" - _gpu_gesvd_maybe_transpose!(A, S, U, Vᴴ) + throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) + _gpu_gesvd!(A, S, U, Vᴴ) elseif alg isa GPU_SVDPolar _gpu_Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...) elseif alg isa GPU_Jacobi diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 3115b3d5..9d0f8cf3 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -46,7 +46,8 @@ function eig_pullback!( Δgauge < gauge_atol || @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) + VᴴΔV ./= conj.(transpose(D) .- D) + diagview(VᴴΔV) .= zero(eltype(VᴴΔV)) if !iszerotangent(ΔDmat) ΔDvec = diagview(ΔDmat) diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index fabc2c2e..1c6de509 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -4,7 +4,7 @@ Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `WP` and cotangent `ΔWP` of `left_polar(A)`. """ -function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP) +function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...) # Extract the Polar components W, P = WP @@ -34,7 +34,7 @@ end Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `PWᴴ` and cotangent `ΔPWᴴ` of `right_polar(A)`. """ -function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ) +function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...) 
# Extract the Polar components P, Wᴴ = PWᴴ diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index c0353a3a..a85b7165 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -26,14 +26,13 @@ function svd_pullback!( degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) - # Extract the SVD components U, Smat, Vᴴ = USVᴴ m, n = size(U, 1), size(Vᴴ, 2) - (m, n) == size(ΔA) || throw(DimensionMismatch()) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) minmn = min(m, n) S = diagview(Smat) - length(S) == minmn || throw(DimensionMismatch()) + length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not match the minimum dimension of U, Vᴴ ($minmn)")) r = searchsortedlast(S, rank_atol; rev = true) # rank Ur = view(U, :, 1:r) Vᴴr = view(Vᴴ, 1:r, :) @@ -44,22 +43,22 @@ function svd_pullback!( UΔU = fill!(similar(U, (r, r)), 0) VΔV = fill!(similar(Vᴴ, (r, r)), 0) if !iszerotangent(ΔU) - m == size(ΔU, 1) || throw(DimensionMismatch()) + m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)")) pU = size(ΔU, 2) - pU > r && throw(DimensionMismatch()) + pU > r && throw(DimensionMismatch("second dimension of ΔU ($(size(ΔU, 2))) does not match rank of S ($r)")) indU = axes(U, 2)[ind] - length(indU) == pU || throw(DimensionMismatch()) + length(indU) == pU || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))")) UΔUp = view(UΔU, :, indU) mul!(UΔUp, Ur', ΔU) # ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1) end if !iszerotangent(ΔVᴴ) - n == size(ΔVᴴ, 2) || throw(DimensionMismatch()) + n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)")) pV
= size(ΔVᴴ, 1) - pV > r && throw(DimensionMismatch()) + pV > r && throw(DimensionMismatch("first dimension of ΔVᴴ ($(size(ΔVᴴ, 1))) does not match rank of S ($r)")) indV = axes(Vᴴ, 1)[ind] - length(indV) == pV || throw(DimensionMismatch()) + length(indV) == pV || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))")) VΔVp = view(VΔV, :, indV) mul!(VΔVp, Vᴴr, ΔVᴴ') # ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ @@ -82,7 +81,7 @@ function svd_pullback!( ΔS = diagview(ΔSmat) pS = length(ΔS) indS = axes(S, 1)[ind] - length(indS) == pS || throw(DimensionMismatch()) + length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))")) view(diagview(UdΔAV), indS) .+= real.(ΔS) end ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA diff --git a/src/pullfwds/eig.jl b/src/pullfwds/eig.jl new file mode 100644 index 00000000..06d2c9e7 --- /dev/null +++ b/src/pullfwds/eig.jl @@ -0,0 +1,12 @@ +function eig_pullfwd!(dA, A, DV, dDV; kwargs...) + D, V = DV + dD, dV = dDV + ∂K = inv(V) * dA * V + ∂Kdiag = diagview(∂K) + dD.diag .= ∂Kdiag + ∂K ./= transpose(diagview(D)) .- diagview(D) + fill!(∂Kdiag, zero(eltype(D))) + mul!(dV, V, ∂K, 1, 0) + dA .= zero(eltype(dA)) + return dDV +end diff --git a/src/pullfwds/eigh.jl b/src/pullfwds/eigh.jl new file mode 100644 index 00000000..91ba0e91 --- /dev/null +++ b/src/pullfwds/eigh.jl @@ -0,0 +1,14 @@ +function eigh_pullfwd!(dA, A, DV, dDV; kwargs...) 
+ D, V = DV # primal outputs: eigenvalues and eigenvectors + dD, dV = dDV # tangent buffers, overwritten below + ∂K = (V \ dA) * V + ∂Kdiag = diag(∂K) + dD.diag .= real.(∂Kdiag) + dDD = transpose(diagview(D)) .- diagview(D) + F = one(eltype(dDD)) ./ dDD + diagview(F) .= zero(eltype(F)) + ∂K .*= F + mul!(dV, V, ∂K) + dA .= zero(eltype(A)) + return (dD, dV) +end diff --git a/src/pullfwds/lq.jl b/src/pullfwds/lq.jl new file mode 100644 index 00000000..02f6b253 --- /dev/null +++ b/src/pullfwds/lq.jl @@ -0,0 +1,58 @@ +function lq_pullfwd!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) + L, Q = LQ + m = size(L, 1) + n = size(Q, 2) + minmn = min(m, n) + Ld = diagview(L) + p = findlast(>=(rank_atol) ∘ abs, Ld) + + n1 = p + n2 = minmn - p + n3 = n - minmn + m1 = p + m2 = m - p + + ##### + Q1 = view(Q, 1:n1, 1:n) # full rank portion + Q2 = view(Q, 1:n1+1:n2+n1, 1:n) + L11 = view(L, 1:m, 1:n1) + L12 = view(L, 1:m1, n1+1:n) + + dA1 = view(dA, 1:m, 1:n1) + dA2 = view(dA, 1:m, (n1 + 1):n) + + dQ, dR = dQR + dQ1 = view(dQ, 1:m, 1:m1) + dQ2 = view(dQ, 1:m, m1+1:m2+m1) + dR11 = view(dR, 1:m1, 1:n1) + dR12 = view(dR, 1:m1, n1+1:n) + dR22 = view(dR, m1+1:m1+m2, n1+1:n) + + # fwd rule for Q1 and R11 -- for a non-rank redeficient QR, this is all we need + invR11 = inv(R11) + tmp = Q1' * dA1 * invR11 + Rtmp = tmp + tmp' + diagview(Rtmp) ./= 2 + ltRtmp = view(Rtmp, MatrixAlgebraKit.lowertriangularind(Rtmp)) + #ltRtmp .= zero(eltype(Rtmp)) + dR11 .= Rtmp * R11 + dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11 + + dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12) + dQ2 .= Q1 * (Q1' * dQ2) + if size(Q2, 2) > 0 + dQ2 .+= Q2 * (Q2' * dQ2) + end + if m3 > 0 && size(dQ2, 2) > 0 + # only present for qr_full or rank-deficient qr_compact + Q3 = view(Q, 1:m, m1+m2+1:size(Q, 2)) + dQ2 .+= Q3 * (Q3' * dQ2) + end + if !isempty(dR22) + _, r22 = qr_full(dA2 - dQ1*R12 - Q1*dR12, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true)) + dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2)) + end + return (dQ, dR) +end +
+function lq_null_pullfwd!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) end diff --git a/src/pullfwds/polar.jl b/src/pullfwds/polar.jl new file mode 100644 index 00000000..5e47da5d --- /dev/null +++ b/src/pullfwds/polar.jl @@ -0,0 +1,2 @@ +function left_polar_pullfwd! end +function right_polar_pullfwd! end diff --git a/src/pullfwds/qr.jl b/src/pullfwds/qr.jl new file mode 100644 index 00000000..9b60842f --- /dev/null +++ b/src/pullfwds/qr.jl @@ -0,0 +1,69 @@ +function qr_pullfwd!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) + Q, R = QR + m = size(Q, 1) + n = size(R, 2) + minmn = min(m, n) + Rd = diagview(R) + p = findlast(>=(rank_atol) ∘ abs, Rd) + + m1 = p + m2 = minmn - p + m3 = m - minmn + n1 = p + n2 = n - p + + Q1 = view(Q, 1:m, 1:m1) # full rank portion + Q2 = view(Q, 1:m, m1+1:m2+m1) + R11 = view(R, 1:m1, 1:n1) + R12 = view(R, 1:m1, n1+1:n) + + dA1 = view(dA, 1:m, 1:n1) + dA2 = view(dA, 1:m, (n1 + 1):n) + + dQ, dR = dQR + dQ1 = view(dQ, 1:m, 1:m1) + dQ2 = view(dQ, 1:m, m1+1:m2+m1) + dQ3 = m1+m2+1 < size(dQ, 2) ? 
view(dQ, 1:m, m1+m2+1:size(dQ,2)) : similar(dQ, eltype(dQ), (0, 0)) + dR11 = view(dR, 1:m1, 1:n1) + dR12 = view(dR, 1:m1, n1+1:n) + dR22 = view(dR, m1+1:m1+m2, n1+1:n) + + # fwd rule for Q1 and R11 -- for a non-rank redeficient QR, this is all we need + invR11 = inv(R11) + tmp = Q1' * dA1 * invR11 + Rtmp = tmp + tmp' + diagview(Rtmp) ./= 2 + ltRtmp = view(Rtmp, MatrixAlgebraKit.lowertriangularind(Rtmp)) + ltRtmp .= zero(eltype(Rtmp)) + dR11 .= Rtmp * R11 + dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11 + dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12) + dQ2 .= Q1 * (Q1' * dQ2) + if size(Q2, 2) > 0 + dQ2 .+= Q2 * (Q2' * dQ2) + end + if m3 > 0 && size(dQ2, 2) > 0 + # only present for qr_full or rank-deficient qr_compact + Q3 = view(Q, 1:m, m1+m2+1:size(Q, 2)) + dQ2 .+= Q3 * (Q3' * dQ2) + end + if !isempty(dR22) + _, r22 = qr_full(dA2 - dQ1*R12 - Q1*dR12, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true)) + dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2)) + end + return (dQ, dR) +end +#=Ac = MatrixAlgebraKit.copy_input(qr_full, Aval) +QR = MatrixAlgebraKit.initialize_output(qr_full!, Aval, alg.val) +Q, R = qr_full!(Ac, QR, alg.val) +Nval = N.val +copy!(Nval, view(Q, 1:size(Aval, 1), (size(Aval, 2) + 1):size(Aval, 1))) +(m, n) = size(Aval) +minmn = min(m, n) +dQ = zeros(eltype(Aval), (m, m)) +view(dQ, 1:m, (minmn + 1):m) .= dN +MatrixAlgebraKit.qr_fwd(dA, A.val, (Q, R), (dQ, zeros(eltype(R), size(R)))) +dN .= view(dQ, 1:m, (minmn + 1):m) +dA .= zero(eltype(A.val))=# + +function qr_null_pullfwd!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) end diff --git a/src/pullfwds/svd.jl b/src/pullfwds/svd.jl new file mode 100644 index 00000000..699de07f --- /dev/null +++ b/src/pullfwds/svd.jl @@ -0,0 +1,21 @@ +function svd_pullfwd!(dA, A, USVᴴ, dUSVᴴ; kwargs...) 
+ U, S, Vᴴ = USVᴴ + dU, dS, dVᴴ = dUSVᴴ + V = adjoint(Vᴴ) + copyto!(dS.diag, diag(real.(U' * dA * V))) + m, n = size(A) + F = one(eltype(S)) ./ (diagview(S)' .- diagview(S)) + F .*= one(eltype(S)) ./ (diagview(S)' .+ diagview(S)) # F[i, j] = 1 / (s_j^2 - s_i^2) + diagview(F) .= zero(eltype(F)) + invSdiag = zeros(eltype(S), length(S.diag)) + for i in 1:length(S.diag) + @inbounds invSdiag[i] = inv(diagview(S)[i]) + end + invS = Diagonal(invSdiag) + ∂U = U * (F .* (U' * dA * V * S + S * Vᴴ * dA' * U)) + (diagm(ones(eltype(U), m)) - U*U') * dA * V * invS + ∂V = V * (F .* (S * U' * dA * V + Vᴴ * dA' * U * S)) + (diagm(ones(eltype(V), n)) - V*Vᴴ) * dA' * U * invS + copyto!(dU, ∂U) + adjoint!(dVᴴ, ∂V) + dA .= zero(eltype(A)) + return (dU, dS, dVᴴ) +end diff --git a/test/ad_utils.jl b/test/ad_utils.jl new file mode 100644 index 00000000..11f3b02e --- /dev/null +++ b/test/ad_utils.jl @@ -0,0 +1,26 @@ +function remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) + gaugepart = U' * ΔU + Vᴴ * ΔVᴴ' + gaugepart = (gaugepart - gaugepart') / 2 + gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end +function remove_eiggauge_dependence!(ΔV, D, V; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(D)) + gaugepart = V' * ΔV + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end +function remove_eighgauge_dependence!(ΔV, D, V; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(D)) + gaugepart = V' * ΔV + gaugepart = (gaugepart - gaugepart') / 2 + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end + +precision(::Type{<:Union{Float32,Complex{Float32}}}) = sqrt(eps(Float32)) +precision(::Type{<:Union{Float64,Complex{Float64}}}) = sqrt(eps(Float64)) diff --git a/test/chainrules.jl
b/test/chainrules.jl index ba3f0681..76eb84c8 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -6,38 +6,7 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! -function remove_svdgauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = U' * ΔU + Vᴴ * ΔVᴴ' - gaugepart = (gaugepart - gaugepart') / 2 - gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 - mul!(ΔU, U, gaugepart, -1, 1) - return ΔU, ΔVᴴ -end -function remove_eiggauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D) - ) - gaugepart = V' * ΔV - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V / (V' * V), gaugepart, -1, 1) - return ΔV -end -function remove_eighgauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D) - ) - gaugepart = V' * ΔV - gaugepart = (gaugepart - gaugepart') / 2 - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V, gaugepart, -1, 1) - return ΔV -end - -precision(::Type{<:Union{Float32, Complex{Float32}}}) = sqrt(eps(Float32)) -precision(::Type{<:Union{Float64, Complex{Float64}}}) = sqrt(eps(Float64)) +include("ad_utils.jl") for f in ( diff --git a/test/cuda/enzyme.jl b/test/cuda/enzyme.jl new file mode 100644 index 00000000..12caad8c --- /dev/null +++ b/test/cuda/enzyme.jl @@ -0,0 +1,326 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using ChainRulesCore +using Enzyme, EnzymeTestUtils +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +is_ci = get(ENV, "CI", "false") == "true" + +ETs = is_ci ? 
(Float64, Float32) : (Float64, Float32, ComplexF64) # Enzyme/#2631 + +include("ad_utils.jl") +@timedtestset "QR AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + A = CuArray(randn(rng, T, m, n)) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + @testset for alg in (LAPACK_HouseholderQR(), + LAPACK_HouseholderQR(; positive=true), + ) + #=@testset "forward" begin + @testset "qr_compact" begin + Q, R = qr_compact(A; alg=alg) + ΔQ = randn(rng, T, m, minmn) + ΔR = randn(rng, T, minmn, n) + # run reverse, then forwards, see if we recover + ΔQ2 = copy(ΔQ) + ΔR2 = copy(ΔR) + ΔA = zeros(T, m, n) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔQ .= zero(T) + ΔR .= zero(T) + ΔA2 = copy(ΔA) + ΔQ, ΔR = MatrixAlgebraKit.qr_fwd!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔA .= zero(T) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + @test ΔA ≈ ΔA2 atol=atol rtol=rtol + #test_forward(qr_compact, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(ΔQ, ΔR)) + end + #=@testset "qr_null" begin + Q, R = qr_compact(A, alg) + ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) + ΔN2 = copy(ΔN) + ΔA = zeros(T, m, n) + ΔA = MatrixAlgebraKit.qr_null_pullback!(ΔA, A, N, ΔN) + #test_forward(qr_null, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=ΔN) + end=# + @testset "qr_full" begin + Q, R = qr_full(A, alg) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn(rng, T, m, m) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn(rng, T, m, n) + ΔQ2 = copy(ΔQ) + ΔR2 = copy(ΔR) + ΔA = zeros(T, m, n) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔA2 = copy(ΔA) + ΔQ, ΔR = MatrixAlgebraKit.qr_fwd!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + @test ΔA ≈ ΔA2 atol=atol rtol=rtol + #test_reverse(qr_full, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(ΔQ, ΔR)) + end + @testset 
"qr_compact - rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + Q, R = qr_compact(Ard, alg) + ΔQ = randn(rng, T, m, minmn) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = randn(rng, T, minmn, n) + view(ΔR, (r + 1):minmn, :) .= 0 + ΔQ2 = copy(ΔQ) + ΔR2 = copy(ΔR) + ΔA = zeros(T, m, n) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔA2 = copy(ΔA) + ΔQ, ΔR = MatrixAlgebraKit.qr_fwd!(ΔA, A, (Q, R), (ΔQ, ΔR)) + ΔA = MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ΔR)) + @test ΔA ≈ ΔA2 atol=atol rtol=rtol + #test_forward(qr_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR)) + end + end=# + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "qr_compact" begin + ΔQ = CuArray(randn(rng, T, m, minmn)) + ΔR = CuArray(randn(rng, T, minmn, n)) + test_reverse(qr_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR)) + end + @testset "qr_null" begin + Q, R = qr_compact(A, alg) + ΔN = Q * CuArray(randn(rng, T, minmn, max(0, m - minmn))) + test_reverse(qr_null, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=ΔN) + end + @testset "qr_full" begin + Q, R = qr_full(A, alg) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = CuArray(randn(rng, T, m, m)) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = CuArray(randn(rng, T, m, n)) + test_reverse(qr_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR)) + end + @testset "qr_compact - rank-deficient A" begin + r = minmn - 5 + Ard = CuArray(randn(rng, T, m, r) * randn(rng, T, r, n)) + Q, R = qr_compact(Ard, alg) + ΔQ = CuArray(randn(rng, T, m, minmn)) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = CuArray(randn(rng, T, minmn, n)) + view(ΔR, (r + 1):minmn, :) .= 0 + 
test_reverse(qr_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR)) + end + end + end + end +end + +@timedtestset "LQ AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + A = CuArray(randn(rng, T, m, n)) + @testset for alg in (LAPACK_HouseholderLQ(), + LAPACK_HouseholderLQ(; positive=true), + ) + #@testset "forward: RT $RT, TA $TA" for RT in (Const,Duplicated,DuplicatedNoNeed), TA in (Const,Duplicated,) + #test_forward(lq_full, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + #test_forward(lq_null, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + #test_forward(lq_compact, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + #end + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "lq_compact" begin + ΔL = CuArray(randn(rng, T, m, minmn)) + ΔQ = CuArray(randn(rng, T, minmn, n)) + test_reverse(lq_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ)) + end + @testset "lq_null" begin + L, Q = lq_compact(A, alg) + ΔNᴴ = CuArray(randn(rng, T, max(0, n - minmn), minmn)) * Q + test_reverse(lq_null, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=ΔNᴴ) + end + @testset "lq_full" begin + L, Q = lq_full(A, alg) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = CuArray(randn(rng, T, n, n)) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + ΔL = CuArray(randn(rng, T, m, n)) + test_reverse(lq_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ)) + end + @testset "lq_compact -- rank-deficient A" begin + r = minmn - 5 + Ard = CuArray(randn(rng, T, m, r) * randn(rng, T, r, n)) + L, Q = lq_compact(Ard, alg) + ΔL = CuArray(randn(rng, T, m, minmn)) + ΔQ = CuArray(randn(rng, T, minmn, n)) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 
1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + test_reverse(lq_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ)) + end + end + end + end +end + +@timedtestset "EIG AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = CuArray(randn(rng, T, m, m)) + D, V = eig_full(A) + ΔV = CuArray(randn(rng, complex(T), m, m)) + ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = CuArray(randn(rng, complex(T), m, m)) + ΔD2 = Diagonal(CuArray(randn(rng, complex(T), m))) # device-resident tangent, built the same way as in the EIGH testset + @testset for alg in (LAPACK_Simple(), LAPACK_Expert()) + @testset "forward: RT $RT, TA $TA" for RT in (Const, Duplicated,), TA in (Const, Duplicated,) + #=ΔV2 = copy(ΔV) + ΔD2_ = copy(ΔD2) + ΔA = zeros(T, m, m) + ΔA = MatrixAlgebraKit.eig_pullback!(ΔA, copy(A), (D, V), (ΔD2, ΔV)) + ΔA2 = copy(ΔA) + ΔD2_, ΔV = MatrixAlgebraKit.eig_full_fwd!(ΔA, copy(A), (D, V), (ΔD2, ΔV)) + ΔA = MatrixAlgebraKit.eig_pullback!(ΔA, copy(A), (D, V), (ΔD2_, ΔV)) + @test ΔA ≈ ΔA2 atol=atol rtol=rtol=# # TODO + test_forward(eig_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + end + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + test_reverse(eig_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) + test_reverse(eig_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) + end + end +end + +@timedtestset "EIGH AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = CuArray(randn(rng, T, m, m)) + A = A + A' + D, V = eigh_full(A) + D2 = Diagonal(D) + ΔV = CuArray(randn(rng, T, m, m)) + ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = CuArray(randn(rng, real(T), m, m)) + ΔD2 = Diagonal(CuArray(randn(rng, real(T), m))) + @testset for alg in (LAPACK_QRIteration(), +
LAPACK_DivideAndConquer(), + LAPACK_Bisection(), + LAPACK_MultipleRelativelyRobustRepresentations(), + ) + @testset "forward: RT $RT, TA $TA" for RT in (Const, Duplicated,), TA in (Const, Duplicated,) + test_forward(eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + end + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + test_reverse(eigh_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) + test_reverse(eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) + end + end +end + +@timedtestset "SVD AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = CuArray(randn(rng, T, m, n)) + minmn = min(m, n) + @testset for alg in (LAPACK_QRIteration(), + LAPACK_DivideAndConquer(), + ) + #=@testset "forward: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + test_rewind(svd_compact, RT, Duplicated, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔU, ΔS, ΔVᴴ)) + #test_forward(svd_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + end=# # TODO + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "svd_compact" begin + U, S, Vᴴ = svd_compact(A) + ΔU = CuArray(randn(rng, T, m, minmn)) + ΔS = Diagonal(CuArray(randn(rng, real(T), minmn))) # tangent kept on the device, like ΔU and ΔVᴴ + ΔVᴴ = CuArray(randn(rng, T, minmn, n)) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + test_reverse(svd_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔU, ΔS, ΔVᴴ)) + end + #= + @testset "svd_full" begin + U, S, Vᴴ = svd_compact(A) + ΔU
= randn(rng, T, m, m) + ΔS = randn(rng, real(T), m, n) + ΔVᴴ = randn(rng, T, n, n) + remove_svdgauge_dependence!(view(ΔU, :, 1:minmn), view(ΔVᴴ, 1:minmn, :), view(U, :, 1:minmn), view(S, 1:minmn, 1:minmn), view(Vᴴ, 1:minmn, :); degeneracy_atol=atol) + test_reverse(svd_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔU, ΔS, ΔVᴴ)) + end =# # TODO + end + end + end +end + +@timedtestset "Polar AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = CuArray(randn(rng, T, m, n)) + @testset for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + m >= n && + test_reverse(left_polar, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + m <= n && + test_reverse(right_polar, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + end + end + end +end + +@timedtestset "Orth and null with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = CuArray(randn(rng, T, m, n)) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "left_orth" begin + @testset for kind in (:polar, :qr) + n > m && kind == :polar && continue + test_reverse(left_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(kind=kind,)) + end + end + @testset "right_orth" begin + @testset for kind in (:polar, :lq) + n < m && kind == :polar && continue + test_reverse(right_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(kind=kind,)) + end + end + @testset "left_null" begin + ΔN = left_orth(A; kind=:qr)[1] * CuArray(randn(rng, T, min(m, n), m - min(m, n))) + test_reverse(left_null, RT, (A, TA); fkwargs=(; kind=:qr), output_tangent=ΔN, atol=atol, rtol=rtol) + end + @testset "right_null" begin + ΔNᴴ = CuArray(randn(rng, T, n - min(m, n), min(m, n))) 
* right_orth(A; kind=:lq)[2] + test_reverse(right_null, RT, (A, TA); fkwargs=(; kind=:lq), output_tangent=ΔNᴴ, atol=atol, rtol=rtol) + end + end + end +end diff --git a/test/cuda/mooncake.jl b/test/cuda/mooncake.jl new file mode 100644 index 00000000..00bf3762 --- /dev/null +++ b/test/cuda/mooncake.jl @@ -0,0 +1,210 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using Mooncake, Mooncake.TestUtils, ChainRulesCore +using Mooncake: rrule!! +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using CUDA +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +function remove_svdgauge_depence!(ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) + gaugepart = U' * ΔU + Vᴴ * ΔVᴴ' + gaugepart = (gaugepart - gaugepart') / 2 + gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end +function remove_eiggauge_depence!(ΔV, D, V; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(D)) # default tolerance is taken from D; no S exists in this signature + gaugepart = V' * ΔV + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end +function remove_eighgauge_depence!(ΔV, D, V; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(D)) # default tolerance is taken from D; no S exists in this signature + gaugepart = V' * ΔV + gaugepart = (gaugepart - gaugepart') / 2 + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end + +precision(::Type{<:Union{Float32,Complex{Float32}}}) = sqrt(eps(Float32)) +precision(::Type{<:Union{Float64,Complex{Float64}}}) = sqrt(eps(Float64)) + +for f in + (:qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, #:eig_full, :eigh_full, + :eigh_full, :svd_compact, :svd_trunc, :left_polar, :right_polar) + copy_f = Symbol(:copy_, f) + f!
= Symbol(f, '!') + @eval begin + function $copy_f(input, alg) + if $f === eigh_full + input = (input + input') / 2 + end + return $f(input, alg) + end + function ChainRulesCore.rrule(::typeof($copy_f), input, alg) + output = MatrixAlgebraKit.initialize_output($f!, input, alg) + if $f === eigh_full + input = (input + input') / 2 + else + input = copy(input) + end + output, pb = ChainRulesCore.rrule($f!, input, output, alg) + return output, x -> (NoTangent(), pb(x)[2], NoTangent()) + end + Mooncake.@from_chainrules Mooncake.DefaultCtx Tuple{typeof($copy_f), AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm} false Mooncake.ReverseMode + end +end + +for f in (:eig_full,)#:eigh_full) + copy_f = Symbol(:copy_, f) + f! = Symbol(f, '!') + @eval begin + function $copy_f(input, alg) + if $f === eigh_full + input = (input + input') / 2 + end + return $f(input, alg) + end + end +end + +@timedtestset "QR AD Rules with eltype $T" for T in (Float64, Float32) #, ComplexF64) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + # qr_compact + atol = rtol = m * n * precision(T) + A = CUDA.randn(rng, T, m, n) + minmn = min(m, n) + alg = CUSOLVER_HouseholderQR(; positive=true) + @testset for f in (copy_qr_compact, copy_qr_null, copy_qr_full) + Mooncake.TestUtils.test_rule(rng, f, A, alg; mode=Mooncake.ReverseMode) + end + + # rank-deficient A + r = minmn - 5 + A = CUDA.randn(rng, T, m, r) * CUDA.randn(rng, T, r, n) + Q, R = qr_compact(A, alg) + Mooncake.TestUtils.test_rule(rng, copy_qr_compact, A, alg; mode=Mooncake.ReverseMode) + end +end + +@timedtestset "LQ AD Rules with eltype $T" for T in (Float64, Float32) #, ComplexF64) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + # lq_compact + atol = rtol = m * n * precision(T) + A = CUDA.randn(rng, T, m, n) + minmn = min(m, n) + alg = CUSOLVER_HouseholderLQ(; positive=true) + @testset for f in (copy_lq_compact, copy_lq_null, copy_lq_full) + 
Mooncake.TestUtils.test_rule(rng, f, A, alg; mode=Mooncake.ReverseMode) + end + # rank-deficient A + r = minmn - 5 + A = CUDA.randn(rng, T, m, r) * CUDA.randn(rng, T, r, n) + Mooncake.TestUtils.test_rule(rng, copy_lq_compact, A, alg; mode=Mooncake.ReverseMode) + end +end + +@timedtestset "EIG AD Rules with eltype $T" for T in (Float64, Float32, ComplexF64) + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = CUDA.randn(rng, T, m, m) + @testset for alg in (CUSOLVER_Simple(), CUSOLVER_Expert()) + Mooncake.TestUtils.test_rule(rng, copy_eig_full, A, alg; mode=Mooncake.ReverseMode) + end +end +@timedtestset "EIGH AD Rules with eltype $T" for T in (Float64, Float32) #, ComplexF64) + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = CUDA.randn(rng, T, m, m) + A = A + A' + @testset for alg in (CUSOLVER_QRIteration(), CUSOLVER_DivideAndConquer(), CUSOLVER_Bisection(), + CUSOLVER_MultipleRelativelyRobustRepresentations()) + # copy_eigh_full includes a projector onto the Hermitian part of the matrix + Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode=Mooncake.ReverseMode) + end +end + +@timedtestset "SVD AD Rules with eltype $T" for T in (Float64, Float32) #, ComplexF64) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = CUDA.randn(rng, T, m, n) + minmn = min(m, n) + U, S, Vᴴ = svd_compact(A) + @testset for alg in (CUSOLVER_QRIteration(), CUSOLVER_DivideAndConquer()) + Mooncake.TestUtils.test_rule(rng, copy_svd_compact, A, alg; mode=Mooncake.ReverseMode) + @testset for r in 1:4:minmn + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + Mooncake.TestUtils.test_rule(rng, copy_svd_trunc, A, truncalg; mode=Mooncake.ReverseMode) + end + truncalg = TruncatedAlgorithm(alg, trunctol(S[1, 1] / 2)) + r = findlast(>=(S[1, 1] / 2), diagview(S)) + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + Mooncake.TestUtils.test_rule(rng, copy_svd_trunc, A, 
truncalg; mode=Mooncake.ReverseMode) + end + end +end + +@timedtestset "Polar AD Rules with eltype $T" for T in (Float64, Float32) #, ComplexF64) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = CUDA.randn(rng, T, m, n) + @testset for alg in PolarViaSVD.((CUSOLVER_QRIteration(), CUSOLVER_DivideAndConquer())) + m >= n && + Mooncake.TestUtils.test_rule(rng, copy_left_polar, A, alg; mode=Mooncake.ReverseMode) + + m <= n && + Mooncake.TestUtils.test_rule(rng, copy_right_polar, A, alg; mode=Mooncake.ReverseMode) + end + end +end + +#= +@timedtestset "Orth and null with eltype $T" for T in (Float64, ComplexF64, Float32) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + config = Zygote.ZygoteRuleConfig() + test_rrule(config, left_orth, A; + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + test_rrule(config, left_orth, A; fkwargs=(; kind=:qr), + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + m >= n && + test_rrule(config, left_orth, A; fkwargs=(; kind=:polar), + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + + ΔN = left_orth(A; kind=:qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + test_rrule(config, left_null, A; fkwargs=(; kind=:qr), output_tangent=ΔN, + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + + test_rrule(config, right_orth, A; + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + test_rrule(config, right_orth, A; fkwargs=(; kind=:lq), + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + m <= n && + test_rrule(config, right_orth, A; fkwargs=(; kind=:polar), + atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lq)[2] + test_rrule(config, right_null, A; fkwargs=(; kind=:lq), output_tangent=ΔNᴴ, 
+ atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) + end +end +=# diff --git a/test/enzyme.jl b/test/enzyme.jl new file mode 100644 index 00000000..0bf77aed --- /dev/null +++ b/test/enzyme.jl @@ -0,0 +1,350 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using ChainRulesCore +using Enzyme, EnzymeTestUtils +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +is_ci = get(ENV, "CI", "false") == "true" + +ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631 +include("ad_utils.jl") + +@timedtestset "QR AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + A = randn(rng, T, m, n) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + @testset for alg in (LAPACK_HouseholderQR(), + LAPACK_HouseholderQR(; positive=true), + ) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "qr_compact" begin + ΔQ = randn(rng, T, m, minmn) + ΔR = randn(rng, T, minmn, n) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(qr_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR), fdm=fdm) + end + @testset "qr_null" begin + Q, R = qr_compact(A, alg) + ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) + test_reverse(qr_null, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=ΔN) + end + @testset "qr_full" begin + Q, R = qr_full(A, alg) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn(rng, T, m, m) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn(rng, T, m, n) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(qr_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR), fdm=fdm) + end + @testset "qr_compact - rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + Q, R = qr_compact(Ard, alg) + ΔQ = randn(rng, T, m, minmn) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = randn(rng, T, minmn, n) + view(ΔR, (r + 1):minmn, :) .= 0 + test_reverse(qr_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR)) + end + end + end + end +end + +@timedtestset "LQ AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + A = randn(rng, T, m, n) + @testset for alg in (LAPACK_HouseholderLQ(), + LAPACK_HouseholderLQ(; positive=true), + ) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "lq_compact" begin + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(lq_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ), fdm=fdm) + end + @testset "lq_null" begin + L, Q = lq_compact(A, alg) + ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q + test_reverse(lq_null, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=ΔNᴴ) + end + @testset "lq_full" begin + L, Q = lq_full(A, alg) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = randn(rng, T, n, n) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + ΔL = randn(rng, T, m, n) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(lq_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ), fdm=fdm) + end + @testset "lq_compact -- rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + L, Q = lq_compact(Ard, alg) + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(lq_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent = (ΔL, ΔQ), fdm=fdm) + end + end + end + end +end + +@timedtestset "EIG AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = randn(rng, T, m, m) + D, V = eig_full(A) + Ddiag = diagview(D) + ΔV = randn(rng, complex(T), m, m) + ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = randn(rng, complex(T), m, m) + ΔD2 = Diagonal(randn(rng, complex(T), m)) + @testset for alg in (LAPACK_Simple(), LAPACK_Expert()) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + test_reverse(eig_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) + test_reverse(eig_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) + end + @testset "reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,) + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(diagview(D), truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = 
Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + # broken right now due to Enzyme + #test_reverse(eig_trunc!, RT, (A, TA), ((D, V), TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=Base.RefValue((ΔDtrunc, ΔVtrunc, zero(real(T))))) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(Ddiag[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + # broken right now due to Enzyme + #test_reverse(eig_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=Base.RefValue((ΔDtrunc, ΔVtrunc, zero(real(T))))) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end +end + +function copy_eigh_full(A; kwargs...) + A = (A + A')/2 + eigh_full(A; kwargs...) +end + +function copy_eigh_vals(A; kwargs...) + A = (A + A')/2 + eigh_vals(A; kwargs...) +end + +function copy_eigh_trunc!(A; kwargs...) + A = (A + A')/2 + DV = MatrixAlgebraKit.initialize_output(eigh_trunc!, A, kwargs[:alg]) + eigh_trunc!(A, DV; kwargs...) 
+end + +@timedtestset "EIGH AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = randn(rng, T, m, m) + A = A + A' + D, V = eigh_full(A) + D2 = Diagonal(D) + ΔV = randn(rng, T, m, m) + ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = randn(rng, real(T), m, m) + ΔD2 = Diagonal(randn(rng, real(T), m)) + @testset for alg in (LAPACK_QRIteration(), + LAPACK_DivideAndConquer(), + LAPACK_Bisection(), + LAPACK_MultipleRelativelyRobustRepresentations(), + ) + @testset "forward: RT $RT, TA $TA" for RT in (Const, Duplicated,), TA in (Const, Duplicated,) + test_forward(copy_eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + end + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + test_reverse(copy_eigh_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) + test_reverse(copy_eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) + end + @testset "reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,) + for r in 1:4:m + Ddiag = diagview(D) + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + # broken right now due to Enzyme + #test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔDtrunc), copy(ΔVtrunc), zero(real(T)))) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + Ddiag = diagview(D) + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) + ind = 
MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + # broken right now due to Enzyme + #test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔDtrunc), copy(ΔVtrunc), zero(real(T)))) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end +end + +@timedtestset "SVD AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + minmn = min(m, n) + @testset for alg in (LAPACK_QRIteration(), + LAPACK_DivideAndConquer(), + ) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "svd_compact" begin + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(svd_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔU, ΔS, ΔVᴴ), fdm=fdm) + end + end + @testset "reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,) + @testset "svd_trunc" begin + for r in 1:4:minmn + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] + ΔVᴴtrunc = ΔVᴴ[ind, :] + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + # broken due to Enzyme + #test_reverse(svd_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm=fdm) + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), copy(A), (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), copy(A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) 
+ Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] + ΔVᴴtrunc = ΔVᴴ[ind, :] + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + # broken due to Enzyme + #test_reverse(svd_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm) + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), copy(A), (copy(U), copy(S), copy(Vᴴ)), (copy(ΔUtrunc), copy(ΔStrunc), copy(ΔVᴴtrunc)), ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), copy(A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end + end + end +end + +@timedtestset "Polar AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + @testset for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + m >= n && + test_reverse(left_polar, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + m <= n && + test_reverse(right_polar, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + end + end + end +end + +@timedtestset "Orth and null with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "left_orth" begin + @testset for kind in (:polar, :qr) + n > m && kind == :polar && continue + test_reverse(left_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(kind=kind,)) + end + end + @testset "right_orth" begin + @testset for kind in (:polar, :lq) + n < m && kind == :polar && continue + 
test_reverse(right_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(kind=kind,)) + end + end + @testset "left_null" begin + ΔN = left_orth(A; kind=:qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + test_reverse(left_null, RT, (A, TA); fkwargs=(; kind=:qr), output_tangent=ΔN, atol=atol, rtol=rtol) + end + @testset "right_null" begin + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lq)[2] + test_reverse(right_null, RT, (A, TA); fkwargs=(; kind=:lq), output_tangent=ΔNᴴ, atol=atol, rtol=rtol) + end + end + end +end + diff --git a/test/mooncake.jl b/test/mooncake.jl new file mode 100644 index 00000000..4263fd25 --- /dev/null +++ b/test/mooncake.jl @@ -0,0 +1,400 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using Mooncake, Mooncake.TestUtils, ChainRulesCore +using Mooncake: rrule!! +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +function Mooncake.increment!!(x::Tuple{Matrix{T}, Mooncake.Tangent{@NamedTuple{diag::Vector{T}}}, Matrix{T}}, y::Tuple{Matrix{T}, Mooncake.Tangent{@NamedTuple{diag::Vector{T}}}, Matrix{T}}) where {T<:Real} + return (Mooncake.increment!!(x[1], y[1]), Mooncake.increment!!(x[2], y[2]), Mooncake.increment!!(x[3], y[3])) +end +function Mooncake.increment!!(x::Tuple{Mooncake.Tangent{@NamedTuple{diag::Vector{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}}, Matrix{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}, y::Tuple{Mooncake.Tangent{@NamedTuple{diag::Vector{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}}, Matrix{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}) where {T<:Real} + return (Mooncake.increment!!(x[1], y[1]), Mooncake.increment!!(x[2], y[2])) +end +function Mooncake.increment!!(x::Tuple{Mooncake.Tangent{@NamedTuple{diag::Vector{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}}, Matrix{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}, 
y::Tuple{Mooncake.Tangent{@NamedTuple{diag::Vector{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}}}, Matrix{Mooncake.Tangent{@NamedTuple{re::T, im::T}}}, Vector{T}}) where {T<:Real} + return (Mooncake.increment!!(x[1], y[1]), Mooncake.increment!!(x[2], y[2])) +end + +include("ad_utils.jl") + +make_mooncake_tangent(ΔAelem::T) where {T<:Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) +make_mooncake_tangent(ΔA::Matrix{<:Real}) = ΔA +make_mooncake_tangent(ΔA::Matrix{T}) where {T<:Complex} = map(make_mooncake_tangent, ΔA) +function make_mooncake_tangent(ΔD::Diagonal{T}) where {T<:Real} + return Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) +end +function make_mooncake_tangent(ΔD::Diagonal{T}) where {T<:Complex} + diag_tangent = map(make_mooncake_tangent, diagview(ΔD)) + return Mooncake.build_tangent(typeof(ΔD), diag_tangent) +end + +ETs = (Float64, Float32, ComplexF64, ComplexF32) + +@timedtestset "QR AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + minmn = min(m, n) + @testset for alg in (LAPACK_HouseholderQR(), + LAPACK_HouseholderQR(; positive=true), + ) + @testset "qr_compact" begin + Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; mode=Mooncake.ReverseMode, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "qr_null" begin + Q, R = qr_compact(A, alg) + ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) + dN = make_mooncake_tangent(ΔN) + Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; mode=Mooncake.ReverseMode, output_tangent = dN, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "qr_full" begin + Q, R = qr_full(A, alg) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn(rng, T, m, m) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn(rng, T, m, n) + dQ = make_mooncake_tangent(ΔQ) + dR = make_mooncake_tangent(ΔR) + dQR = Mooncake.build_tangent(typeof((ΔQ,ΔR)), dQ, dR) + 
Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; mode=Mooncake.ReverseMode, output_tangent = dQR, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "qr_compact - rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + Q, R = qr_compact(Ard, alg) + ΔQ = randn(rng, T, m, minmn) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = randn(rng, T, minmn, n) + view(ΔR, (r + 1):minmn, :) .= 0 + dQ = make_mooncake_tangent(ΔQ) + dR = make_mooncake_tangent(ΔR) + dQR = Mooncake.build_tangent(typeof((ΔQ,ΔR)), dQ, dR) + Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; mode=Mooncake.ReverseMode, output_tangent = dQR, is_primitive=false, atol=atol, rtol=rtol) + end + end + end +end + +@timedtestset "LQ AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + minmn = min(m, n) + @testset for alg in (LAPACK_HouseholderLQ(), + LAPACK_HouseholderLQ(; positive=true), + ) + @testset "lq_compact" begin + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + dL = make_mooncake_tangent(ΔL) + dQ = make_mooncake_tangent(ΔQ) + dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ) + Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; mode=Mooncake.ReverseMode, is_primitive=false, atol=atol, rtol=rtol, output_tangent = dLQ) + end + @testset "lq_null" begin + L, Q = lq_compact(A, alg) + ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; mode=Mooncake.ReverseMode, output_tangent = dNᴴ, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "lq_full" begin + L, Q = lq_full(A, alg) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = randn(rng, T, n, n) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + ΔL = randn(rng, T, m, n) + dL = 
make_mooncake_tangent(ΔL) + dQ = make_mooncake_tangent(ΔQ) + dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ) + Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; mode=Mooncake.ReverseMode, output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "lq_compact - rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + L, Q = lq_compact(Ard, alg) + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + dL = make_mooncake_tangent(ΔL) + dQ = make_mooncake_tangent(ΔQ) + dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ) + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; mode=Mooncake.ReverseMode, output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol) + end + end + end +end + +@timedtestset "EIG AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = randn(rng, T, m, m) + DV = eig_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = randn(rng, complex(T), m, m) + ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = randn(rng, complex(T), m, m) + ΔD2 = Diagonal(randn(rng, complex(T), m)) + + dD = make_mooncake_tangent(ΔD2) + dV = make_mooncake_tangent(ΔV) + dDV = Mooncake.build_tangent(typeof((ΔD2,ΔV)), dD, dV) + # compute the dA corresponding to the above dD, dV + @testset for alg in (LAPACK_Simple(), LAPACK_Expert()) + @testset "eig_full" begin + Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode=Mooncake.ReverseMode, output_tangent = dDV, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "eig_vals" begin + Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; atol=atol, rtol=rtol, is_primitive=false) + end + @testset "eig_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, 
truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + dDtrunc = make_mooncake_tangent(ΔDtrunc) + dVtrunc = make_mooncake_tangent(ΔVtrunc) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc,ΔVtrunc,zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode=Mooncake.ReverseMode, output_tangent=dDVtrunc, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + dDtrunc = make_mooncake_tangent(ΔDtrunc) + dVtrunc = make_mooncake_tangent(ΔVtrunc) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc,ΔVtrunc,zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode=Mooncake.ReverseMode, output_tangent=dDVtrunc, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end +end + +function copy_eigh_full(A, alg; kwargs...) + A = (A + A')/2 + eigh_full(A, alg; kwargs...) +end + +function copy_eigh_vals(A, alg; kwargs...) + A = (A + A')/2 + eigh_vals(A, alg; kwargs...) +end + +function copy_eigh_trunc(A, alg; kwargs...) + A = (A + A')/2 + eigh_trunc(A, alg; kwargs...) 
+end + +@timedtestset "EIGH AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = randn(rng, T, m, m) + A = A + A' + D, V = eigh_full(A) + ΔV = randn(rng, T, m, m) + ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol=atol) + ΔD = randn(rng, real(T), m, m) + ΔD2 = Diagonal(randn(rng, real(T), m)) + dD = make_mooncake_tangent(ΔD2) + dV = make_mooncake_tangent(ΔV) + dDV = Mooncake.build_tangent(typeof((ΔD2,ΔV)), dD, dV) + Ddiag = diagview(D) + @testset for alg in (LAPACK_QRIteration(), + LAPACK_DivideAndConquer(), + LAPACK_Bisection(), + LAPACK_MultipleRelativelyRobustRepresentations(), + ) + @testset "eigh_full" begin + Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode=Mooncake.ReverseMode, output_tangent=dDV, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "eigh_vals" begin + Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; is_primitive=false, atol=atol, rtol=rtol) + end + @testset "eigh_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + dDtrunc = make_mooncake_tangent(ΔDtrunc) + dVtrunc = make_mooncake_tangent(ΔVtrunc) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc,ΔVtrunc,zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode=Mooncake.ReverseMode, output_tangent=dDVtrunc, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) + ind = 
MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + dDtrunc = make_mooncake_tangent(ΔDtrunc) + dVtrunc = make_mooncake_tangent(ΔVtrunc) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc,ΔVtrunc,zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode=Mooncake.ReverseMode, output_tangent=dDVtrunc, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end +end + +function dummy_svd_trunc(A, args...; kwargs...) + U, S, Vᴴ, ϵ = svd_trunc(A, args...; kwargs...) + return U, S, Vᴴ +end + +@timedtestset "SVD AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + minmn = min(m, n) + @testset for alg in (LAPACK_QRIteration(), + LAPACK_DivideAndConquer(), + ) + @testset "svd_compact" begin + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + dS = make_mooncake_tangent(ΔS2) + dU = make_mooncake_tangent(ΔU) + dVᴴ = make_mooncake_tangent(ΔVᴴ) + dUSVᴴ = Mooncake.build_tangent(typeof((ΔU,ΔS2,ΔVᴴ)), dU, dS, dVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; mode=Mooncake.ReverseMode, output_tangent=dUSVᴴ, is_primitive=false, atol=atol, rtol=rtol) + end + @testset "svd_vals" begin + Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; is_primitive=false, atol=atol, rtol=rtol) + end + @testset "svd_trunc" begin + @testset for r 
in 1:4:minmn + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] + ΔVᴴtrunc = ΔVᴴ[ind, :] + dStrunc = make_mooncake_tangent(ΔStrunc) + dUtrunc = make_mooncake_tangent(ΔUtrunc) + dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) + dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU,ΔS2,ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, dummy_svd_trunc, copy(A), truncalg; mode=Mooncake.ReverseMode, output_tangent=dUSVᴴerr, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] + ΔVᴴtrunc = ΔVᴴ[ind, :] + dStrunc = make_mooncake_tangent(ΔStrunc) + dUtrunc = make_mooncake_tangent(ΔUtrunc) + dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) + dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU,ΔS2,ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) + 
Mooncake.TestUtils.test_rule(rng, dummy_svd_trunc, copy(A), truncalg; mode=Mooncake.ReverseMode, output_tangent=dUSVᴴerr, atol=atol, rtol=rtol, is_primitive=false) + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end + end +end + +@timedtestset "Polar AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + @testset for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) + m >= n && + Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode=Mooncake.ReverseMode, is_primitive=false, atol=atol, rtol=rtol) + + m <= n && + Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; mode=Mooncake.ReverseMode, is_primitive=false, atol=atol, rtol=rtol) + end + end +end + +@timedtestset "Orth and null with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + Mooncake.TestUtils.test_rule(rng, left_orth, A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) + Mooncake.TestUtils.test_rule(rng, right_orth, A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) + + Mooncake.TestUtils.test_rule(rng, (X->left_orth(X; kind=:qr)), A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) + if m >= n + Mooncake.TestUtils.test_rule(rng, (X->left_orth(X; kind=:polar)), A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) + end + + ΔN = left_orth(A; kind=:qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + dN = make_mooncake_tangent(ΔN) + Mooncake.TestUtils.test_rule(rng, (X->left_null(X; kind=:qr)), A; mode=Mooncake.ReverseMode, atol=atol, 
rtol=rtol, is_primitive=false, output_tangent = dN) + + Mooncake.TestUtils.test_rule(rng, (X->right_orth(X; kind=:lq)), A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) + + if m <= n + Mooncake.TestUtils.test_rule(rng, (X->right_orth(X; kind=:polar)), A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) + end + + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lq)[2] + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, (X->right_null(X; kind=:lq)), A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false, output_tangent = dNᴴ) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ec255538..2f39856f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,12 @@ if !is_buildkite @safetestset "Image and Null Space" begin include("orthnull.jl") end + @safetestset "Mooncake" begin + include("mooncake.jl") + end + @safetestset "Enzyme" begin + include("enzyme.jl") + end @safetestset "ChainRules" begin include("chainrules.jl") end @@ -81,6 +87,12 @@ if CUDA.functional() @safetestset "CUDA Image and Null Space" begin include("cuda/orthnull.jl") end + #=@safetestset "CUDA Mooncake" begin + include("cuda/mooncake.jl") + end + @safetestset "CUDA Enzyme" begin + include("cuda/enzyme.jl") + end=# end using AMDGPU @@ -106,6 +118,9 @@ if AMDGPU.functional() @safetestset "AMDGPU Image and Null Space" begin include("amd/orthnull.jl") end + #=@safetestset "AMDGPU Enzyme" begin + include("amd/enzyme.jl") + end=# end using GenericLinearAlgebra From bf9324cf4779829303ac636f1e8157c22770f5dd Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 22 Oct 2025 17:01:48 +0200 Subject: [PATCH 2/6] Add dummy reverse modes for truncated methods --- .../MatrixAlgebraKitEnzymeExt.jl | 54 ++++++++------ src/MatrixAlgebraKit.jl | 12 ++-- src/implementations/svd.jl | 54 ++++++++++---- src/pullfwds/polar.jl | 2 - src/{pullfwds => pushforwards}/eig.jl | 2 +- src/{pullfwds => 
pushforwards}/eigh.jl | 2 +- src/{pullfwds => pushforwards}/lq.jl | 2 +- src/pushforwards/polar.jl | 2 + src/{pullfwds => pushforwards}/qr.jl | 2 +- src/{pullfwds => pushforwards}/svd.jl | 2 +- test/enzyme.jl | 72 ++++++++++++++----- 11 files changed, 139 insertions(+), 67 deletions(-) delete mode 100644 src/pullfwds/polar.jl rename src/{pullfwds => pushforwards}/eig.jl (84%) rename src/{pullfwds => pushforwards}/eigh.jl (88%) rename src/{pullfwds => pushforwards}/lq.jl (92%) create mode 100644 src/pushforwards/polar.jl rename src/{pullfwds => pushforwards}/qr.jl (94%) rename src/{pullfwds => pushforwards}/svd.jl (93%) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index db577784..6d6fa4c3 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -1,7 +1,7 @@ module MatrixAlgebraKitEnzymeExt using MatrixAlgebraKit -using MatrixAlgebraKit: diagview, inv_safe +using MatrixAlgebraKit: diagview, inv_safe, eig_trunc!, eigh_trunc! using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pullfwd!, lq_pullfwd! using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pullfwd!, lq_null_pullfwd! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pullfwd!, eigh_pullfwd! 
@@ -321,7 +321,7 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, A::Annotation{<:AbstractMatrix}, USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, ϵ::Annotation{Vector{T}}, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; kwargs..., ) where {RT, T<:Real} # form cache if needed @@ -350,7 +350,7 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, A::Annotation{<:AbstractMatrix}, USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}}, ϵ::Annotation{Vector{T}}, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; kwargs...) where {RT, T<:Real} cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache U, S, Vᴴ = cache_USVᴴ @@ -363,7 +363,7 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, make_zero!(USVᴴ.dval) end if !isa(ϵ, Const) - ϵ.dval .= zero(T) + make_zero!(ϵ.dval) end return (nothing, nothing, nothing, nothing) end @@ -476,16 +476,17 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Type{RT}, A::Annotation{<:AbstractMatrix}, DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + ϵ::Annotation{Vector{T}}, + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; kwargs..., - ) where {RT} + ) where {RT, T} # form cache if needed cache_A = copy(A.val) - eigh_full!(A.val, DV.val, alg.val.alg) + MatrixAlgebraKit.eigh_full!(A.val, DV.val, alg.val.alg) cache_DV = copy.(DV.val) DV′, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc) - ϵ = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind) - primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ) : nothing + ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind) + primal = EnzymeRules.needs_primal(config) ? 
(DV′..., ϵ.val) : nothing shadow_DV = if !isa(A, Const) && !isa(DV, Const) dD, dV = DV.dval dDtrunc = Diagonal(diagview(dD)[ind]) @@ -494,17 +495,18 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, else (nothing, nothing) end - shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., zero(ϵ)) : nothing + shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind)) end function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(eigh_trunc!)}, - dret, + ::Type{RT}, cache, A::Annotation{<:AbstractMatrix}, DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + ϵ::Annotation{Vector{T}}, alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; - kwargs...) + kwargs...) where {RT, T} cache_A, cache_DV, cache_dDVtrunc, ind = cache D, V = cache_DV dD, dV = cache_dDVtrunc @@ -515,24 +517,28 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, if !isa(DV, Const) make_zero!(DV.dval) end + if !isa(ϵ, Const) + make_zero!(ϵ.dval) + end return (nothing, nothing, nothing, nothing) end -#= + function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(eig_trunc!)}, ::Type{RT}, A::Annotation{<:AbstractMatrix}, DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}; + ϵ::Annotation{Vector{T}}, + alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; kwargs..., - ) where {RT} + ) where {RT, T} # form cache if needed cache_A = copy(A.val) eig_full!(A.val, DV.val, alg.val.alg) cache_DV = copy.(DV.val) DV′, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc) - ϵ = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind) - primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ) : nothing + ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind) + primal = EnzymeRules.needs_primal(config) ? 
(DV′..., ϵ.val) : nothing shadow_DV = if !isa(A, Const) && !isa(DV, Const) dD, dV = DV.dval dDtrunc = Diagonal(diagview(dD)[ind]) @@ -541,30 +547,34 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, else (nothing, nothing) end - shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., zero(ϵ)) : nothing + shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind)) end function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(eig_trunc!)}, - dret, + ::Type{RT}, cache, A::Annotation{<:AbstractMatrix}, DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}}, + ϵ::Annotation{Vector{T}}, alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm}; - kwargs...) + kwargs...) where {RT, T} cache_A, cache_DV, cache_dDVtrunc, ind = cache D, V = cache_DV dD, dV = cache_dDVtrunc if !isa(A, Const) && !isa(DV, Const) A.dval .= zero(eltype(A.val)) - A.dval .= MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, (D, V), (dD, dV), ind; kwargs...) + A.dval .= MatrixAlgebraKit.eig_pullback!(A.dval, A.val, (D, V), (dD, dV), ind; kwargs...) 
end if !isa(DV, Const) make_zero!(DV.dval) end + if !isa(ϵ, Const) + make_zero!(ϵ.dval) + end return (nothing, nothing, nothing, nothing) end -=# + function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(eigh_full!)}, ::Type{RT}, diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 3cda864c..1d2f6f61 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -112,11 +112,11 @@ include("pullbacks/eigh.jl") include("pullbacks/svd.jl") include("pullbacks/polar.jl") -include("pullfwds/qr.jl") -include("pullfwds/lq.jl") -include("pullfwds/eig.jl") -include("pullfwds/eigh.jl") -include("pullfwds/polar.jl") -include("pullfwds/svd.jl") +include("pushforwards/qr.jl") +include("pushforwards/lq.jl") +include("pushforwards/eig.jl") +include("pushforwards/eigh.jl") +include("pushforwards/polar.jl") +include("pushforwards/svd.jl") end diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 50f352b0..fed36cd1 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -89,7 +89,7 @@ end function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlgorithm) return similar(A, real(eltype(A)), (min(size(A)...),)) end -function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm) +function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm) return initialize_output(svd_compact!, A, alg.alg) end @@ -333,25 +333,46 @@ function _gpu_gesvdj!( ) throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ))) end +function _gpu_gesvd_maybe_transpose!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix) + m, n = size(A) + m ≥ n && return _gpu_gesvd!(A, S, U, Vᴴ) + # both CUSOLVER and ROCSOLVER require m ≥ n for gesvd (QR_Iteration) + # if this condition is not met, do the SVD via adjoint + minmn = min(m, n) + Aᴴ = min(m, n) > 0 ? 
adjoint!(similar(A'), A)::AbstractMatrix : similar(A') + Uᴴ = similar(U') + V = similar(Vᴴ') + if size(U) == (m, m) + _gpu_gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ) + else + _gpu_gesvd!(Aᴴ, S, V, Uᴴ) + end + length(U) > 0 && adjoint!(U, Uᴴ) + length(Vᴴ) > 0 && adjoint!(Vᴴ, V) + return U, S, Vᴴ +end + # GPU SVD implementation -function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) +function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_full!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ fill!(S, zero(eltype(S))) m, n = size(A) minmn = min(m, n) + if minmn == 0 + one!(U) + zero!(S) + one!(Vᴴ) + return USVᴴ + end if alg isa GPU_QRIteration isempty(alg.kwargs) || - throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) - _gpu_gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ) + @warn "GPU_QRIteration does not accept any keyword arguments" + _gpu_gesvd_maybe_transpose!(A, view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa GPU_SVDPolar _gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) elseif alg isa GPU_Jacobi _gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) - # elseif alg isa LAPACK_Bisection - # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) - # elseif alg isa LAPACK_Jacobi - # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) else throw(ArgumentError("Unsupported SVD algorithm")) end @@ -368,16 +389,21 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) # TODO: make this controllable using a `gaugefix` keyword argument gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...) 
- return first(truncate(svd_trunc!, USVᴴ, alg.trunc)) + # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong + USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + Strunc = diagview(USVᴴtrunc[2]) + # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum; the radicand is clamped at zero because roundoff can make the difference of squared norms slightly negative, which would make `sqrt` throw a DomainError + ϵ = sqrt(max(norm(A)^2 - norm(Strunc)^2, 0)) # is there a more accurate way to do this? + return USVᴴtrunc..., ϵ end -function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) +function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ if alg isa GPU_QRIteration isempty(alg.kwargs) || - throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) - _gpu_gesvd!(A, S.diag, U, Vᴴ) + @warn "GPU_QRIteration does not accept any keyword arguments" + _gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ) elseif alg isa GPU_SVDPolar _gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...) elseif alg isa GPU_Jacobi @@ -397,8 +423,8 @@ function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) if alg isa GPU_QRIteration isempty(alg.kwargs) || - throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) - _gpu_gesvd!(A, S, U, Vᴴ) + @warn "GPU_QRIteration does not accept any keyword arguments" + _gpu_gesvd_maybe_transpose!(A, S, U, Vᴴ) elseif alg isa GPU_SVDPolar _gpu_Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...) elseif alg isa GPU_Jacobi diff --git a/src/pullfwds/polar.jl b/src/pullfwds/polar.jl deleted file mode 100644 index 5e47da5d..00000000 --- a/src/pullfwds/polar.jl +++ /dev/null @@ -1,2 +0,0 @@ -function left_polar_pullfwd! end -function right_polar_pullfwd! 
end diff --git a/src/pullfwds/eig.jl b/src/pushforwards/eig.jl similarity index 84% rename from src/pullfwds/eig.jl rename to src/pushforwards/eig.jl index 06d2c9e7..8a6808e5 100644 --- a/src/pullfwds/eig.jl +++ b/src/pushforwards/eig.jl @@ -1,4 +1,4 @@ -function eig_pullfwd!(dA, A, DV, dDV; kwargs...) +function eig_pushforward!(dA, A, DV, dDV; kwargs...) D, V = DV dD, dV = dDV ∂K = inv(V) * dA * V diff --git a/src/pullfwds/eigh.jl b/src/pushforwards/eigh.jl similarity index 88% rename from src/pullfwds/eigh.jl rename to src/pushforwards/eigh.jl index 91ba0e91..c050a913 100644 --- a/src/pullfwds/eigh.jl +++ b/src/pushforwards/eigh.jl @@ -1,4 +1,4 @@ -function eigh_pullfwd!(dA, A, DV, dDV; kwargs...) +function eigh_pushforward!(dA, A, DV, dDV; kwargs...) tmpV = V \ dA ∂K = tmpV * V ∂Kdiag = diag(∂K) diff --git a/src/pullfwds/lq.jl b/src/pushforwards/lq.jl similarity index 92% rename from src/pullfwds/lq.jl rename to src/pushforwards/lq.jl index 02f6b253..224d6b7c 100644 --- a/src/pullfwds/lq.jl +++ b/src/pushforwards/lq.jl @@ -1,4 +1,4 @@ -function lq_pullfwd!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) +function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) L, Q = LQ m = size(L, 1) n = size(Q, 2) diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl new file mode 100644 index 00000000..803771df --- /dev/null +++ b/src/pushforwards/polar.jl @@ -0,0 +1,2 @@ +function left_polar_pushforward! end +function right_polar_pushforward! 
end diff --git a/src/pullfwds/qr.jl b/src/pushforwards/qr.jl similarity index 94% rename from src/pullfwds/qr.jl rename to src/pushforwards/qr.jl index 9b60842f..9bf7f523 100644 --- a/src/pullfwds/qr.jl +++ b/src/pushforwards/qr.jl @@ -1,4 +1,4 @@ -function qr_pullfwd!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) +function qr_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) Q, R = QR m = size(Q, 1) n = size(R, 2) diff --git a/src/pullfwds/svd.jl b/src/pushforwards/svd.jl similarity index 93% rename from src/pullfwds/svd.jl rename to src/pushforwards/svd.jl index 699de07f..2a1f162a 100644 --- a/src/pullfwds/svd.jl +++ b/src/pushforwards/svd.jl @@ -1,4 +1,4 @@ -function svd_pullfwd!(dA, A, USVᴴ, dUSVᴴ; kwargs...) +function svd_pushforward!(dA, A, USVᴴ, dUSVᴴ; kwargs...) U, S, Vᴴ = USVᴴ dU, dS, dVᴴ = dUSVᴴ V = adjoint(Vᴴ) diff --git a/test/enzyme.jl b/test/enzyme.jl index 0bf77aed..85fb51b9 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -113,6 +113,19 @@ end end end +function MatrixAlgebraKit.eig_trunc!(A, DV, ϵ::Vector{T}, alg::MatrixAlgebraKit.TruncatedAlgorithm) where {T} + D, V = eig_full!(A, DV, alg.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate(eig_trunc!, (D, V), alg.trunc) + ϵ[1] = MatrixAlgebraKit.truncation_error!(diagview(D), ind) + return DVtrunc..., ϵ +end +function dummy_eig_trunc(A, ϵ::Vector{T}, alg::TruncatedAlgorithm) where {T} + Ac = MatrixAlgebraKit.copy_input(MatrixAlgebraKit.eig_trunc, A) + DV = MatrixAlgebraKit.initialize_output(eig_trunc!, A, alg) + Dtrunc, Vtrunc, ϵ = MatrixAlgebraKit.eig_trunc!(Ac, DV, ϵ, alg) + return Dtrunc, Vtrunc, ϵ +end + @timedtestset "EIG AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 @@ -129,7 +142,7 @@ end test_reverse(eig_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) 
test_reverse(eig_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) end - @testset "reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) for r in 1:4:m truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) ind = MatrixAlgebraKit.findtruncated(diagview(D), truncalg.trunc) @@ -137,8 +150,8 @@ end Vtrunc = V[:, ind] ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) ΔVtrunc = ΔV[:, ind] - # broken right now due to Enzyme - #test_reverse(eig_trunc!, RT, (A, TA), ((D, V), TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=Base.RefValue((ΔDtrunc, ΔVtrunc, zero(real(T))))) + ϵ = [zero(real(T))] + test_reverse(dummy_eig_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))])) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -149,8 +162,8 @@ end Vtrunc = V[:, ind] ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) ΔVtrunc = ΔV[:, ind] - # broken right now due to Enzyme - #test_reverse(eig_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=Base.RefValue((ΔDtrunc, ΔVtrunc, zero(real(T))))) + ϵ = [zero(real(T))] + test_reverse(dummy_eig_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))])) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -168,10 +181,18 @@ function copy_eigh_vals(A; kwargs...) eigh_vals(A; kwargs...) end -function copy_eigh_trunc!(A; kwargs...) 
+function MatrixAlgebraKit.eigh_trunc!(A, DV, ϵ::Vector{T}, alg::MatrixAlgebraKit.TruncatedAlgorithm) where {T} + D, V = eigh_full!(A, DV, alg.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate(eigh_trunc!, (D, V), alg.trunc) + ϵ[1] = MatrixAlgebraKit.truncation_error!(diagview(D), ind) + return DVtrunc..., ϵ +end +function dummy_eigh_trunc(A, ϵ::Vector{T}, alg::TruncatedAlgorithm) where {T} A = (A + A')/2 - DV = MatrixAlgebraKit.initialize_output(eigh_trunc!, A, kwargs[:alg]) - eigh_trunc!(A, DV; kwargs...) + Ac = MatrixAlgebraKit.copy_input(MatrixAlgebraKit.eigh_trunc, A) + DV = MatrixAlgebraKit.initialize_output(eigh_trunc!, A, alg) + Dtrunc, Vtrunc, ϵ = MatrixAlgebraKit.eigh_trunc!(Ac, DV, ϵ, alg) + return Dtrunc, Vtrunc, ϵ end @timedtestset "EIGH AD Rules with eltype $T" for T in ETs @@ -194,11 +215,11 @@ end @testset "forward: RT $RT, TA $TA" for RT in (Const, Duplicated,), TA in (Const, Duplicated,) test_forward(copy_eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) end - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) test_reverse(copy_eigh_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) test_reverse(copy_eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) end - @testset "reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) for r in 1:4:m Ddiag = diagview(D) truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) @@ -207,8 +228,8 @@ end Vtrunc = V[:, ind] ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) ΔVtrunc = ΔV[:, ind] - # broken right now due to Enzyme - #test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔDtrunc), copy(ΔVtrunc), zero(real(T)))) + ϵ = 
[zero(real(T))] + test_reverse(dummy_eigh_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))])) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -220,8 +241,8 @@ end Vtrunc = V[:, ind] ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) ΔVtrunc = ΔV[:, ind] - # broken right now due to Enzyme - #test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔDtrunc), copy(ΔVtrunc), zero(real(T)))) + ϵ = [zero(real(T))] + test_reverse(dummy_eigh_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))])) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -229,6 +250,19 @@ end end end +function MatrixAlgebraKit.svd_trunc!(A, USVᴴ, ϵ::Vector{T}, alg::MatrixAlgebraKit.TruncatedAlgorithm) where {T} + U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + ϵ[1] = MatrixAlgebraKit.truncation_error!(diagview(S), ind) + return USVᴴtrunc..., ϵ +end +function dummy_svd_trunc(A, ϵ::Vector{T}, alg::TruncatedAlgorithm) where {T} + Ac = MatrixAlgebraKit.copy_input(MatrixAlgebraKit.svd_trunc, A) + USVᴴ = MatrixAlgebraKit.initialize_output(svd_trunc!, A, alg) + Utrunc, Strunc, Vᴴtrunc, ϵ = MatrixAlgebraKit.svd_trunc!(Ac, USVᴴ, ϵ, alg) + return Utrunc, Strunc, Vᴴtrunc, ϵ +end + @timedtestset "SVD AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 @@ -239,7 +273,7 @@ end @testset for alg in (LAPACK_QRIteration(), LAPACK_DivideAndConquer(), ) - @testset "reverse: RT $RT, TA $TA" for RT 
in (Duplicated,), TA in (Duplicated,) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) @testset "svd_compact" begin U, S, Vᴴ = svd_compact(A) ΔU = randn(rng, T, m, minmn) @@ -250,7 +284,7 @@ end test_reverse(svd_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔU, ΔS, ΔVᴴ), fdm=fdm) end end - @testset "reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) @testset "svd_trunc" begin for r in 1:4:minmn U, S, Vᴴ = svd_compact(A) @@ -269,7 +303,8 @@ end ΔVᴴtrunc = ΔVᴴ[ind, :] fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) # broken due to Enzyme - #test_reverse(svd_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm=fdm) + ϵ = [zero(real(T))] + test_reverse(dummy_svd_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]), fdm=fdm) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), copy(A), (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), copy(A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -290,7 +325,8 @@ end ΔVᴴtrunc = ΔVᴴ[ind, :] fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) # broken due to Enzyme - #test_reverse(svd_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm) + ϵ = [zero(real(T))] + test_reverse(dummy_svd_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]), fdm=fdm) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), copy(A), (copy(U), copy(S), copy(Vᴴ)), (copy(ΔUtrunc), copy(ΔStrunc), copy(ΔVᴴtrunc)), ind) dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), copy(A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) From d8e6d6f1ccc10f7e9c44dec1c2768d7b42d61fd3 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 24 Oct 2025 10:31:13 +0200 Subject: [PATCH 3/6] Incremental --- .../MatrixAlgebraKitEnzymeExt.jl | 26 +++---- .../MatrixAlgebraKitMooncakeExt.jl | 60 +++++++++------- src/pushforwards/lq.jl | 2 +- src/pushforwards/qr.jl | 2 +- src/pushforwards/svd.jl | 24 +++++-- test/enzyme.jl | 37 +++++++--- test/mooncake.jl | 69 ++++++++++++------- test/runtests.jl | 10 +-- 8 files changed, 148 insertions(+), 82 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 6d6fa4c3..0c2330a2 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -2,10 +2,10 @@ module MatrixAlgebraKitEnzymeExt using MatrixAlgebraKit using MatrixAlgebraKit: diagview, inv_safe, eig_trunc!, eigh_trunc! -using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pullfwd!, lq_pullfwd! -using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pullfwd!, lq_null_pullfwd! 
-using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pullfwd!, eigh_pullfwd! -using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pullfwd!, right_polar_pullfwd! +using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pushforward!, lq_pushforward! +using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pushforward!, lq_null_pushforward! +using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pushforward!, eigh_pushforward! +using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pushforward!, right_polar_pushforward! using Enzyme using Enzyme.EnzymeCore using Enzyme.EnzymeCore: EnzymeRules @@ -15,13 +15,13 @@ using LinearAlgebra # two-argument factorizations like LQ, QR, EIG -for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pullfwd!), - (qr_compact!, qr_pullback!, qr_pullfwd!), - (lq_full!, lq_pullback!, lq_pullfwd!), - (lq_compact!, lq_pullback!, lq_pullfwd!), - (eig_full!, eig_pullback!, eig_pullfwd!), - (left_polar!, left_polar_pullback!, left_polar_pullfwd!), - (right_polar!, right_polar_pullback!, right_polar_pullfwd!), +for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pushforward!), + (qr_compact!, qr_pullback!, qr_pushforward!), + (lq_full!, lq_pullback!, lq_pushforward!), + (lq_compact!, lq_pullback!, lq_pushforward!), + (eig_full!, eig_pullback!, eig_pushforward!), + (left_polar!, left_polar_pullback!, left_polar_pushforward!), + (right_polar!, right_polar_pullback!, right_polar_pushforward!), ) @eval begin function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, @@ -95,8 +95,8 @@ for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pullfwd!), end end -for (f, pb, pf) in ((qr_null!, qr_null_pullback!, qr_null_pullfwd!), - (lq_null!, lq_null_pullback!, lq_null_pullfwd!), +for (f, pb, pf) in ((qr_null!, qr_null_pullback!, qr_null_pushforward!), + (lq_null!, lq_null_pullback!, lq_null_pushforward!), ) @eval begin function 
EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 8af2a95f..76af0dc9 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -4,22 +4,22 @@ using Mooncake using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive using MatrixAlgebraKit using MatrixAlgebraKit: inv_safe, diagview -using MatrixAlgebraKit: svd_pullfwd! -using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pullfwd!, lq_pullfwd! -using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pullfwd!, lq_null_pullfwd! -using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pullfwd!, eigh_pullfwd! -using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pullfwd!, right_polar_pullfwd! +using MatrixAlgebraKit: svd_pushforward! +using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pushforward!, lq_pushforward! +using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pushforward!, lq_null_pushforward! +using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pushforward!, eigh_pushforward! +using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pushforward!, right_polar_pushforward! 
using LinearAlgebra # two-argument factorizations like LQ, QR, EIG -for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pullfwd!, :dqr_adjoint), - (qr_compact!, qr_pullback!, qr_pullfwd!, :dqr_adjoint), - (lq_full!, lq_pullback!, lq_pullfwd!, :dlq_adjoint), - (lq_compact!, lq_pullback!, lq_pullfwd!, :dlq_adjoint), - (eig_full!, eig_pullback!, eig_pullfwd!, :deig_adjoint), - (eigh_full!, eigh_pullback!, eigh_pullfwd!, :deigh_adjoint), - (left_polar!, left_polar_pullback!, left_polar_pullfwd!, :dleft_polar_adjoint), - (right_polar!, right_polar_pullback!, right_polar_pullfwd!, :dright_polar_adjoint), +for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoint), + (qr_compact!, qr_pullback!, qr_pushforward!, :dqr_adjoint), + (lq_full!, lq_pullback!, lq_pushforward!, :dlq_adjoint), + (lq_compact!, lq_pullback!, lq_pushforward!, :dlq_adjoint), + (eig_full!, eig_pullback!, eig_pushforward!, :deig_adjoint), + (eigh_full!, eigh_pullback!, eigh_pushforward!, :deigh_adjoint), + (left_polar!, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint), + (right_polar!, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint), ) @eval begin @@ -55,14 +55,14 @@ for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pullfwd!, :dqr_adjoint), end end -for (f, pb, pf, adj) in ((qr_null!, qr_null_pullback!, qr_null_pullfwd!, :dqr_null_adjoint), - (lq_null!, lq_null_pullback!, lq_null_pullfwd!, :dlq_null_adjoint), - ) +for (f, f_full, pb, pf, adj) in ((qr_null!, qr_full, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint), + (lq_null!, lq_full, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint), + ) @eval begin @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, arg_darg::CoDual{<:AbstractMatrix}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; 
kwargs...) A, dA = arrayify(A_dA) - Ac = MatrixAlgebraKit.copy_input(lq_full, A) + Ac = MatrixAlgebraKit.copy_input($f_full, A) arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg)) arg = $f(Ac, arg, Mooncake.primal(alg_dalg)) function $adj(::Mooncake.NoRData) @@ -72,7 +72,16 @@ for (f, pb, pf, adj) in ((qr_null!, qr_null_pullback!, qr_null_pullfwd!, :dqr_nu end return arg_darg, $adj end - #forward mode not implemented yet + @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(f_df::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, arg_darg::Dual{<:AbstractMatrix}, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) + A, dA = arrayify(A_dA) + Ac = MatrixAlgebraKit.copy_input($f_full, A) + arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg)) + arg = $f(Ac, arg, Mooncake.primal(alg_dalg)) + $pf(dA, A, arg, darg; kwargs...) + dA .= zero(dA) + return arg_darg + end end end @@ -145,7 +154,7 @@ function Mooncake.frule!!(::Dual{typeof(MatrixAlgebraKit.eigh_full!)}, A_dA::Dua D, dD = arrayify(DV[1], dDV[1]) V, dV = arrayify(DV[2], dDV[2]) (D, V) = eigh_full!(A, DV, Mooncake.primal(alg_dalg); kwargs...) - (dD, dV) = eigh_pullfwd!(dA, A, (D, V), (dD, dV); kwargs...) + (dD, dV) = eigh_pushforward!(dA, A, (D, V), (dD, dV); kwargs...) 
return Mooncake.Dual(DV, dDV) end =# @@ -209,10 +218,13 @@ for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal)) vdU = view(dU, :, 1:minmn) vdS = Diagonal(diagview(dS)[1:minmn]) vdVᴴ = view(dVᴴ, 1:minmn, :) - dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (vdU, vdS, vdVᴴ)) + dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) end return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() end + dU .= zero(dU) + dS .= zero(dS) + dVᴴ .= zero(dVᴴ) return Mooncake.CoDual(USVᴴ, dUSVᴴ), dsvd_adjoint end @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} @@ -228,10 +240,10 @@ for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal)) # update tangents U_, S_, Vᴴ_ = USVᴴ dU_, dS_, dVᴴ_ = dUSVᴴ - U, dU = arrayify(U_, dU_) - S, dS = arrayify(S_, dS_) - Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_) - (dU, dS, dVᴴ) = svd_pullfwd!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...) + U, dU = arrayify(U_, dU_) + S, dS = arrayify(S_, dS_) + Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_) + (dU, dS, dVᴴ) = svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...) 
return USVᴴ_dUSVᴴ end end diff --git a/src/pushforwards/lq.jl b/src/pushforwards/lq.jl index 224d6b7c..1a220628 100644 --- a/src/pushforwards/lq.jl +++ b/src/pushforwards/lq.jl @@ -55,4 +55,4 @@ function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pull return (dQ, dR) end -function lq_null_pullfwd!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) end +function lq_null_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) end diff --git a/src/pushforwards/qr.jl b/src/pushforwards/qr.jl index 9bf7f523..c2068337 100644 --- a/src/pushforwards/qr.jl +++ b/src/pushforwards/qr.jl @@ -66,4 +66,4 @@ MatrixAlgebraKit.qr_fwd(dA, A.val, (Q, R), (dQ, zeros(eltype(R), size(R)))) dN .= view(dQ, 1:m, (minmn + 1):m) dA .= zero(eltype(A.val))=# -function qr_null_pullfwd!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) end +function qr_null_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) end diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl index 2a1f162a..f6cf0c41 100644 --- a/src/pushforwards/svd.jl +++ b/src/pushforwards/svd.jl @@ -1,19 +1,31 @@ -function svd_pushforward!(dA, A, USVᴴ, dUSVᴴ; kwargs...) 
+function svd_pushforward!(dA, A, USVᴴ, dUSVᴴ; + tol::Real = default_pullback_gaugetol(USVᴴ[2]), + rank_atol::Real = tol, + degeneracy_atol::Real = tol, + gauge_atol::Real = tol + ) U, S, Vᴴ = USVᴴ dU, dS, dVᴴ = dUSVᴴ V = adjoint(Vᴴ) - copyto!(dS.diag, diag(real.(U' * dA * V))) + UdAV = U' * dA * V + copyto!(diagview(dS), diag(real.(UdAV))) m, n = size(A) F = one(eltype(S)) ./ (diagview(S)' .- diagview(S)) G = one(eltype(S)) ./ (diagview(S)' .+ diagview(S)) diagview(F) .= zero(eltype(F)) - invSdiag = zeros(eltype(S), length(S.diag)) - for i in 1:length(S.diag) + invSdiag = zeros(eltype(S), length(diagview(S))) + for i in 1:length(diagview(S)) @inbounds invSdiag[i] = inv(diagview(S)[i]) end invS = Diagonal(invSdiag) - ∂U = U * (F .* (U' * dA * V * S + S * Vᴴ * dA' * U)) + (diagm(ones(eltype(U), m)) - U*U') * dA * V * invS - ∂V = V * (F .* (S * U' * dA * V + Vᴴ * dA' * U * S)) + (diagm(ones(eltype(V), n)) - V*Vᴴ) * dA' * U * invS + #∂U = U * (F .* (U' * dA * V * S + S * Vᴴ * dA' * U)) + (LinearAlgebra.diagm(ones(eltype(U), m)) - U*U') * dA * V * invS + #∂V = V * (F .* (S * U' * dA * V + Vᴴ * dA' * U * S)) + (LinearAlgebra.diagm(ones(eltype(V), n)) - V*Vᴴ) * dA' * U * invS + hUdAV = F .* project_hermitian(UdAV) + aUdAV = G .* project_antihermitian(UdAV) + ∂U = U * (hUdAV + aUdAV) + ∂U += (LinearAlgebra.diagm(ones(eltype(U), m)) - U*U') * dA * V * invS + ∂V = V * (hUdAV - aUdAV) + ∂V += (LinearAlgebra.diagm(ones(eltype(U), n)) - V*V') * dA' * U * invS copyto!(dU, ∂U) adjoint!(dVᴴ, ∂V) dA .= zero(eltype(A)) diff --git a/test/enzyme.jl b/test/enzyme.jl index 85fb51b9..b9b16c45 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -9,9 +9,10 @@ using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! is_ci = get(ENV, "CI", "false") == "true" -ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631 +#ETs = is_ci ? 
(Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631 +ETs = (Float64, ComplexF64) include("ad_utils.jl") - +#= @timedtestset "QR AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 @@ -112,7 +113,7 @@ end end end end - +=# function MatrixAlgebraKit.eig_trunc!(A, DV, ϵ::Vector{T}, alg::MatrixAlgebraKit.TruncatedAlgorithm) where {T} D, V = eig_full!(A, DV, alg.alg) DVtrunc, ind = MatrixAlgebraKit.truncate(eig_trunc!, (D, V), alg.trunc) @@ -133,12 +134,20 @@ end A = randn(rng, T, m, m) D, V = eig_full(A) Ddiag = diagview(D) + ΔA = randn(rng, T, m, m) ΔV = randn(rng, complex(T), m, m) ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol=atol) ΔD = randn(rng, complex(T), m, m) ΔD2 = Diagonal(randn(rng, complex(T), m)) @testset for alg in (LAPACK_Simple(), LAPACK_Expert()) - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "forward: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + # make A hermitian + Ah = (A + A')/2 + test_forward((Ah; kwargs...)->eig_full(A; kwargs...)[1], RT, Duplicated(A, ΔA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + test_forward((Ah; kwargs...)->eig_full(A; kwargs...)[2], RT, Duplicated(A, ΔA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + #test_forward(eig_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + end + #=@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) test_reverse(eig_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) test_reverse(eig_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) end @@ -167,7 +176,7 @@ end dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end + end=# end end @@ -213,9 +222,11 @@ end 
LAPACK_MultipleRelativelyRobustRepresentations(), ) @testset "forward: RT $RT, TA $TA" for RT in (Const, Duplicated,), TA in (Const, Duplicated,) + test_forward((A; kwargs...)->copy_eigh_full(A; kwargs...)[1], RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + test_forward((A; kwargs...)->copy_eigh_full(A; kwargs...)[2], RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) test_forward(copy_eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) end - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + #=@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) test_reverse(copy_eigh_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) test_reverse(copy_eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) end @@ -246,7 +257,7 @@ end dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end + end=# end end @@ -273,7 +284,15 @@ end @testset for alg in (LAPACK_QRIteration(), LAPACK_DivideAndConquer(), ) - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "forward: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "svd_compact" begin + A = randn(rng, T, m, m) + Ah = (A + A')/2 + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_forward(svd_compact, RT, (Ah, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), fdm=fdm) + end + end + #=@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) @testset "svd_compact" begin U, S, Vᴴ = svd_compact(A) ΔU = randn(rng, T, m, minmn) @@ -331,7 +350,7 @@ end dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), copy(A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) end - end + end=# end end end diff --git a/test/mooncake.jl b/test/mooncake.jl index 4263fd25..ee9a39b6 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -30,7 +30,7 @@ function make_mooncake_tangent(ΔD::Diagonal{T}) where {T<:Complex} return Mooncake.build_tangent(typeof(ΔD), diag_tangent) end -ETs = (Float64, Float32, ComplexF64, ComplexF32) +ETs = (Float64, Float32,)# ComplexF64, ComplexF32) @timedtestset "QR AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) @@ -43,13 +43,13 @@ ETs = (Float64, Float32, ComplexF64, ComplexF32) LAPACK_HouseholderQR(; positive=true), ) @testset "qr_compact" begin - Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; mode=Mooncake.ReverseMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; is_primitive=false, atol=atol, rtol=rtol) end @testset "qr_null" begin Q, R = qr_compact(A, alg) ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; mode=Mooncake.ReverseMode, output_tangent = dN, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; output_tangent = dN, is_primitive=false, atol=atol, rtol=rtol) end @testset "qr_full" begin Q, R = qr_full(A, alg) @@ -61,7 +61,7 @@ ETs = (Float64, Float32, ComplexF64, ComplexF32) dQ = make_mooncake_tangent(ΔQ) dR = 
make_mooncake_tangent(ΔR) dQR = Mooncake.build_tangent(typeof((ΔQ,ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; mode=Mooncake.ReverseMode, output_tangent = dQR, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; output_tangent = dQR, is_primitive=false, atol=atol, rtol=rtol) end @testset "qr_compact - rank-deficient A" begin r = minmn - 5 @@ -77,7 +77,7 @@ ETs = (Float64, Float32, ComplexF64, ComplexF32) dQ = make_mooncake_tangent(ΔQ) dR = make_mooncake_tangent(ΔR) dQR = Mooncake.build_tangent(typeof((ΔQ,ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; mode=Mooncake.ReverseMode, output_tangent = dQR, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; output_tangent = dQR, is_primitive=false, atol=atol, rtol=rtol) end end end @@ -158,7 +158,9 @@ end # compute the dA corresponding to the above dD, dV @testset for alg in (LAPACK_Simple(), LAPACK_Expert()) @testset "eig_full" begin - Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode=Mooncake.ReverseMode, output_tangent = dDV, is_primitive=false, atol=atol, rtol=rtol) + Ah = (A + A')/2 + Mooncake.TestUtils.test_rule(rng, (A, alg) -> eig_full(exp(A), alg)[1], Ah, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, (A, alg) -> eig_full(exp(A), alg)[2], Ah, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) end @testset "eig_vals" begin Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; atol=atol, rtol=rtol, is_primitive=false) @@ -283,24 +285,42 @@ end A = randn(rng, T, m, n) minmn = min(m, n) @testset for alg in (LAPACK_QRIteration(), - LAPACK_DivideAndConquer(), + #LAPACK_DivideAndConquer(), ) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = 
remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) + ΔUref = copy(ΔU) + ΔSref = copy(ΔS2) + ΔVᴴref = copy(ΔVᴴ) @testset "svd_compact" begin - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=atol) dS = make_mooncake_tangent(ΔS2) dU = make_mooncake_tangent(ΔU) dVᴴ = make_mooncake_tangent(ΔVᴴ) dUSVᴴ = Mooncake.build_tangent(typeof((ΔU,ΔS2,ΔVᴴ)), dU, dS, dVᴴ) Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; mode=Mooncake.ReverseMode, output_tangent=dUSVᴴ, is_primitive=false, atol=atol, rtol=rtol) end + @testset "svd_full" begin + ΔUfull = zeros(T, m, m) + ΔSfull = zeros(real(T), m, n) + ΔVᴴfull = zeros(T, n, n) + U, S, Vᴴ = svd_full(A) + view(ΔUfull, :, 1:minmn) .= ΔUref + view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴref + diagview(ΔSfull)[1:minmn] .= diagview(ΔSref) + dS = make_mooncake_tangent(ΔSfull) + dU = make_mooncake_tangent(ΔUfull) + dVᴴ = make_mooncake_tangent(ΔVᴴfull) + dUSVᴴ = Mooncake.build_tangent(typeof((ΔUfull,ΔSfull,ΔVᴴfull)), dU, dS, dVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_full, A, alg; mode=Mooncake.ReverseMode, output_tangent=dUSVᴴ, is_primitive=false, atol=atol, rtol=rtol) + end @testset "svd_vals" begin Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; is_primitive=false, atol=atol, rtol=rtol) end + #= @testset "svd_trunc" begin @testset for r in 1:4:minmn U, S, Vᴴ = svd_compact(A) @@ -348,7 +368,7 @@ end dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end + end=# end end end @@ -361,10 +381,10 @@ end A = randn(rng, T, m, n) @testset for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) m >= n && - 
Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode=Mooncake.ReverseMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; is_primitive=false, atol=atol, rtol=rtol) m <= n && - Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; mode=Mooncake.ReverseMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; is_primitive=false, atol=atol, rtol=rtol) end end end @@ -375,26 +395,27 @@ end @testset "size ($m, $n)" for n in (17, m, 23) atol = rtol = m * n * precision(T) A = randn(rng, T, m, n) - Mooncake.TestUtils.test_rule(rng, left_orth, A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) - Mooncake.TestUtils.test_rule(rng, right_orth, A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) + Mooncake.TestUtils.test_rule(rng, left_orth, A; atol=atol, rtol=rtol, is_primitive=false) + Mooncake.TestUtils.test_rule(rng, right_orth, A; atol=atol, rtol=rtol, is_primitive=false) - Mooncake.TestUtils.test_rule(rng, (X->left_orth(X; kind=:qr)), A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) + Mooncake.TestUtils.test_rule(rng, (X->left_orth(X; kind=:qr)), A; atol=atol, rtol=rtol, is_primitive=false) if m >= n - Mooncake.TestUtils.test_rule(rng, (X->left_orth(X; kind=:polar)), A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) + Mooncake.TestUtils.test_rule(rng, (X->left_orth(X; kind=:polar)), A; atol=atol, rtol=rtol, is_primitive=false) end ΔN = left_orth(A; kind=:qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, (X->left_null(X; kind=:qr)), A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false, output_tangent = dN) + Mooncake.TestUtils.test_rule(rng, (X->left_null(X; kind=:qr)), A; atol=atol, rtol=rtol, is_primitive=false, output_tangent = dN) - Mooncake.TestUtils.test_rule(rng, (X->right_orth(X; kind=:lq)), A; 
mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) + Mooncake.TestUtils.test_rule(rng, (X->right_orth(X; kind=:lq)), A; atol=atol, rtol=rtol, is_primitive=false) if m <= n - Mooncake.TestUtils.test_rule(rng, (X->right_orth(X; kind=:polar)), A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false) + Mooncake.TestUtils.test_rule(rng, (X->right_orth(X; kind=:polar)), A; atol=atol, rtol=rtol, is_primitive=false) end ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lq)[2] dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, (X->right_null(X; kind=:lq)), A; mode=Mooncake.ReverseMode, atol=atol, rtol=rtol, is_primitive=false, output_tangent = dNᴴ) + Mooncake.TestUtils.test_rule(rng, (X->right_null(X; kind=:lq)), A; atol=atol, rtol=rtol, is_primitive=false, output_tangent = dNᴴ) end end + diff --git a/test/runtests.jl b/test/runtests.jl index 2f39856f..2e1f7b0b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using SafeTestsets # specific ones is_buildkite = get(ENV, "BUILDKITE", "false") == "true" if !is_buildkite - @safetestset "Algorithms" begin + #=@safetestset "Algorithms" begin include("algorithms.jl") end @safetestset "Projections" begin @@ -38,12 +38,14 @@ if !is_buildkite @safetestset "Image and Null Space" begin include("orthnull.jl") end + =# @safetestset "Mooncake" begin include("mooncake.jl") end - @safetestset "Enzyme" begin + #=@safetestset "Enzyme" begin include("enzyme.jl") - end + end=# + #= @safetestset "ChainRules" begin include("chainrules.jl") end @@ -58,7 +60,7 @@ if !is_buildkite using JET JET.test_package(MatrixAlgebraKit; target_defined_modules = true) end - end + end=# end using CUDA From 78cfd67eb1982645bc3294e98dcab6aa7a3c998f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 27 Oct 2025 15:02:41 +0100 Subject: [PATCH 4/6] Some QR progress --- .../MatrixAlgebraKitMooncakeExt.jl | 20 +++--- src/pushforwards/lq.jl | 72 +++++++++++-------- 
src/pushforwards/qr.jl | 33 ++++----- test/mooncake.jl | 46 ++++++------ 4 files changed, 88 insertions(+), 83 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 76af0dc9..9408a7b3 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -25,12 +25,12 @@ for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoi @eval begin @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) - A, dA = arrayify(A_dA) - dA .= zero(eltype(A)) - args = Mooncake.primal(args_dargs) - dargs = Mooncake.tangent(args_dargs) - arg1, darg1 = arrayify(args[1], dargs[1]) - arg2, darg2 = arrayify(args[2], dargs[2]) + A, dA = arrayify(A_dA) + dA .= zero(eltype(A)) + args = Mooncake.primal(args_dargs) + dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(args[1], dargs[1]) + arg2, darg2 = arrayify(args[2], dargs[2]) function $adj(::Mooncake.NoRData) dA = $pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...) return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData() @@ -42,10 +42,10 @@ for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoi end @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...) 
- A, dA = arrayify(A_dA) - args = Mooncake.primal(args_dargs) - args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...) - dargs = Mooncake.tangent(args_dargs) + A, dA = arrayify(A_dA) + args = Mooncake.primal(args_dargs) + args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...) + dargs = Mooncake.tangent(args_dargs) arg1, darg1 = arrayify(args[1], dargs[1]) arg2, darg2 = arrayify(args[2], dargs[2]) darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2)) diff --git a/src/pushforwards/lq.jl b/src/pushforwards/lq.jl index 1a220628..2d390a59 100644 --- a/src/pushforwards/lq.jl +++ b/src/pushforwards/lq.jl @@ -1,11 +1,19 @@ -function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) +#=function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) L, Q = LQ + dL, dQ = dLQ m = size(L, 1) n = size(Q, 2) minmn = min(m, n) Ld = diagview(L) p = findlast(>=(rank_atol) ∘ abs, Ld) + if p == minmn && size(L,1) == size(L,2) # full-rank + invL = inv(L) + dQ .= invL * (dA - dL * Q) + dL = invL * dA * Q' + return (dL, dQ) + end + n1 = p n2 = minmn - p n3 = n - minmn @@ -13,46 +21,48 @@ function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pull m2 = m - p ##### - Q1 = view(Q, 1:n1, 1:n) # full rank portion - Q2 = view(Q, 1:n1+1:n2+n1, 1:n) - L11 = view(L, 1:m, 1:n1) - L12 = view(L, 1:m1, n1+1:n) + Q1 = view(Q, 1:m1, 1:n) # full rank portion + Q2 = view(Q, n1+1:n1+n2, 1:n) + L11 = view(L, 1:m1, 1:n1) + L21 = view(L, (m1+1):m, 1:n1) - dA1 = view(dA, 1:m, 1:n1) - dA2 = view(dA, 1:m, (n1 + 1):n) + dA1 = view(dA, 1:m1, 1:n) + dA2 = view(dA, (m1+1):m, 1:n) - dQ, dR = dQR - dQ1 = view(dQ, 1:m, 1:m1) - dQ2 = view(dQ, 1:m, m1+1:m2+m1) - dR11 = view(dR, 1:m1, 1:n1) - dR12 = view(dR, 1:m1, n1+1:n) - dR22 = view(dR, m1+1:m1+m2, n1+1:n) + dQ1 = view(dQ, 1:n1, 1:n) + dQ2 = view(dQ, n1+1:n1+n2, 1:n) + dL11 = 
view(dL, 1:m1, 1:n1) + dL21 = view(dL, (m1+1):m, 1:n1) + dL22 = view(dL, (m1+1):m, n1+1:(n1+n2) ) # fwd rule for Q1 and R11 -- for a non-rank redeficient QR, this is all we need - invR11 = inv(R11) - tmp = Q1' * dA1 * invR11 - Rtmp = tmp + tmp' - diagview(Rtmp) ./= 2 - ltRtmp = view(Rtmp, MatrixAlgebraKit.lowertriangularind(Rtmp)) - #ltRtmp .= zero(eltype(Rtmp)) - dR11 .= Rtmp * R11 - dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11 - - dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12) - dQ2 .= Q1 * (Q1' * dQ2) - if size(Q2, 2) > 0 + invL11 = inv(L11) + tmp = invL11 * dA1 * Q1' + Ltmp = tmp + tmp' + diagview(Ltmp) ./= 2 + utLtmp = view(Ltmp, MatrixAlgebraKit.uppertriangularind(Ltmp)) + dL11 .= L11 * Ltmp + dQ1 .= invL11 * dA1 - invL11 * dL11 * Q1 + + dL21 .= (dA2 - L21 * dQ1) * adjoint(Q1) + dQ2 .= -(dQ2 * Q1') * Q1 + if size(Q2, 1) > 0 dQ2 .+= Q2 * (Q2' * dQ2) end - if m3 > 0 && size(dQ2, 2) > 0 + if n3 > 0 && size(dQ2, 1) > 0 # only present for qr_full or rank-deficient qr_compact - Q3 = view(Q, 1:m, m1+m2+1:size(Q, 2)) + Q3 = view(Q, (n1+n2+1):n, 1:n) dQ2 .+= Q3 * (Q3' * dQ2) end - if !isempty(dR22) - _, r22 = qr_full(dA2 - dQ1*R12 - Q1*dR12, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true)) - dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2)) + if !isempty(dL22) + _, l22 = qr_full(dA2 - L21 * dQ1 - dL12 * Q1, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true)) + dL22 .= view(l22, 1:size(dL22, 1), 1:size(dL22, 2)) end - return (dQ, dR) + return (dL, dQ) +end=# + +function lq_pushforward!(dA, A, LQ, dLQ; kwargs...) + qr_pushforward!(dA, A, (adjoint(LQ[2]), adjoint(LQ[1])), (adjoint(dLQ[2]), adjoint(dLQ[1])); kwargs...) 
end function lq_null_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) end diff --git a/src/pushforwards/qr.jl b/src/pushforwards/qr.jl index c2068337..fbc8e017 100644 --- a/src/pushforwards/qr.jl +++ b/src/pushforwards/qr.jl @@ -1,7 +1,7 @@ function qr_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) Q, R = QR - m = size(Q, 1) - n = size(R, 2) + m = size(A, 1) + n = size(A, 2) minmn = min(m, n) Rd = diagview(R) p = findlast(>=(rank_atol) ∘ abs, Rd) @@ -23,7 +23,7 @@ function qr_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pull dQ, dR = dQR dQ1 = view(dQ, 1:m, 1:m1) dQ2 = view(dQ, 1:m, m1+1:m2+m1) - dQ3 = m1+m2+1 < size(dQ, 2) ? view(dQ, 1:m, m1+m2+1:size(dQ,2)) : similar(dQ, eltype(dQ), (0, 0)) + dQ3 = minmn+1 < size(dQ, 2) ? view(dQ, :, minmn+1:size(dQ,2)) : similar(dQ, eltype(dQ), (0, 0)) dR11 = view(dR, 1:m1, 1:n1) dR12 = view(dR, 1:m1, n1+1:n) dR22 = view(dR, m1+1:m1+m2, n1+1:n) @@ -38,32 +38,23 @@ function qr_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pull dR11 .= Rtmp * R11 dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11 dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12) - dQ2 .= Q1 * (Q1' * dQ2) + dQ2 .= -Q1 * (Q1' * dQ2) if size(Q2, 2) > 0 dQ2 .+= Q2 * (Q2' * dQ2) end - if m3 > 0 && size(dQ2, 2) > 0 + if m3 > 0 && size(Q, 2) > minmn # only present for qr_full or rank-deficient qr_compact - Q3 = view(Q, 1:m, m1+m2+1:size(Q, 2)) - dQ2 .+= Q3 * (Q3' * dQ2) + Q′ = view(Q, :, 1:minmn) + println("minmn $minmn m $m") + Q3 = view(Q, :, minmn+1:m) + #dQ3 .= Q′ * (Q′' * Q3) + dQ3 .= Q3 end - if !isempty(dR22) + #=if !isempty(dR22) _, r22 = qr_full(dA2 - dQ1*R12 - Q1*dR12, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true)) dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2)) - end + end=# return (dQ, dR) end -#=Ac = MatrixAlgebraKit.copy_input(qr_full, Aval) -QR = 
MatrixAlgebraKit.initialize_output(qr_full!, Aval, alg.val) -Q, R = qr_full!(Ac, QR, alg.val) -Nval = N.val -copy!(Nval, view(Q, 1:size(Aval, 1), (size(Aval, 2) + 1):size(Aval, 1))) -(m, n) = size(Aval) -minmn = min(m, n) -dQ = zeros(eltype(Aval), (m, m)) -view(dQ, 1:m, (minmn + 1):m) .= dN -MatrixAlgebraKit.qr_fwd(dA, A.val, (Q, R), (dQ, zeros(eltype(R), size(R)))) -dN .= view(dQ, 1:m, (minmn + 1):m) -dA .= zero(eltype(A.val))=# function qr_null_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) end diff --git a/test/mooncake.jl b/test/mooncake.jl index ee9a39b6..f73c4db4 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -30,27 +30,27 @@ function make_mooncake_tangent(ΔD::Diagonal{T}) where {T<:Complex} return Mooncake.build_tangent(typeof(ΔD), diag_tangent) end -ETs = (Float64, Float32,)# ComplexF64, ComplexF32) +ETs = (Float64,)# Float32,)# ComplexF64, ComplexF32) @timedtestset "QR AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) + @testset "size ($m, $n)" for n in (17,)# m, 23) atol = rtol = m * n * precision(T) A = randn(rng, T, m, n) minmn = min(m, n) @testset for alg in (LAPACK_HouseholderQR(), - LAPACK_HouseholderQR(; positive=true), + #LAPACK_HouseholderQR(; positive=true), ) - @testset "qr_compact" begin + #=@testset "qr_compact" begin Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; is_primitive=false, atol=atol, rtol=rtol) - end - @testset "qr_null" begin + end=# + #=@testset "qr_null" begin Q, R = qr_compact(A, alg) ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) dN = make_mooncake_tangent(ΔN) Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; output_tangent = dN, is_primitive=false, atol=atol, rtol=rtol) - end + end=# @testset "qr_full" begin Q, R = qr_full(A, alg) Q1 = view(Q, 1:m, 1:minmn) @@ -61,9 +61,11 @@ ETs = (Float64, Float32,)# ComplexF64, ComplexF32) dQ = make_mooncake_tangent(ΔQ) dR 
= make_mooncake_tangent(ΔR) dQR = Mooncake.build_tangent(typeof((ΔQ,ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; output_tangent = dQR, is_primitive=false, atol=atol, rtol=rtol) + #Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[2]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + #Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, 1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, minmn+1:m]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) end - @testset "qr_compact - rank-deficient A" begin + #=@testset "qr_compact - rank-deficient A" begin r = minmn - 5 Ard = randn(rng, T, m, r) * randn(rng, T, r, n) Q, R = qr_compact(Ard, alg) @@ -77,12 +79,13 @@ ETs = (Float64, Float32,)# ComplexF64, ComplexF32) dQ = make_mooncake_tangent(ΔQ) dR = make_mooncake_tangent(ΔR) dQR = Mooncake.build_tangent(typeof((ΔQ,ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; output_tangent = dQR, is_primitive=false, atol=atol, rtol=rtol) - end + Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + end=# end end end +#= @timedtestset "LQ AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 @@ -99,14 +102,14 @@ end dL = make_mooncake_tangent(ΔL) dQ = make_mooncake_tangent(ΔQ) dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; mode=Mooncake.ReverseMode, is_primitive=false, atol=atol, rtol=rtol, output_tangent = dLQ) + Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; is_primitive=false, atol=atol, rtol=rtol, output_tangent = dLQ) end - @testset "lq_null" begin + #=@testset "lq_null" begin L, Q = lq_compact(A, alg) ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q dNᴴ = 
make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; mode=Mooncake.ReverseMode, output_tangent = dNᴴ, is_primitive=false, atol=atol, rtol=rtol) - end + Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; output_tangent = dNᴴ, is_primitive=false, atol=atol, rtol=rtol) + end=# @testset "lq_full" begin L, Q = lq_full(A, alg) Q1 = view(Q, 1:minmn, 1:n) @@ -117,9 +120,9 @@ end dL = make_mooncake_tangent(ΔL) dQ = make_mooncake_tangent(ΔQ) dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; mode=Mooncake.ReverseMode, output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol) end - @testset "lq_compact - rank-deficient A" begin + #=@testset "lq_compact - rank-deficient A" begin r = minmn - 5 Ard = randn(rng, T, m, r) * randn(rng, T, r, n) L, Q = lq_compact(Ard, alg) @@ -133,12 +136,13 @@ end dL = make_mooncake_tangent(ΔL) dQ = make_mooncake_tangent(ΔQ) dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; mode=Mooncake.ReverseMode, output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol) - end + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol) + end=# end end end - +=# +#= @timedtestset "EIG AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 @@ -418,4 +422,4 @@ end Mooncake.TestUtils.test_rule(rng, (X->right_null(X; kind=:lq)), A; atol=atol, rtol=rtol, is_primitive=false, output_tangent = dNᴴ) end end - +=# From d64d2c5ad1947c199bafb0e64cefa78f5262e144 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 6 Nov 2025 11:35:36 +0100 Subject: [PATCH 5/6] Working frule for polar --- .../MatrixAlgebraKitMooncakeExt.jl | 15 +++++++++-- src/pushforwards/polar.jl | 27 +++++++++++++++++-- 2 files changed, 38 insertions(+), 4 
deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 9408a7b3..17971578 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -216,7 +216,7 @@ for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal)) vS = Diagonal(diagview(S)[1:minmn]) vVᴴ = view(Vᴴ, 1:minmn, :) vdU = view(dU, :, 1:minmn) - vdS = Diagonal(diagview(dS)[1:minmn]) + vdS = view(dS, 1:minmn, 1:minmn) vdVᴴ = view(dVᴴ, 1:minmn, :) dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) end @@ -243,7 +243,18 @@ for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal)) U, dU = arrayify(U_, dU_) S, dS = arrayify(S_, dS_) Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_) - (dU, dS, dVᴴ) = svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...) + minmn = min(size(A)...) + if ($f == svd_compact!) # compact + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...) + else # full + vU = view(U, :, 1:minmn) + vS = view(S, 1:minmn, 1:minmn) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(dU, :, 1:minmn) + vdS = view(dS, 1:minmn, 1:minmn) + vdVᴴ = view(dVᴴ, 1:minmn, :) + svd_pushforward!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ); kwargs...) + end return USVᴴ_dUSVᴴ end end diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl index 803771df..bb9c801e 100644 --- a/src/pushforwards/polar.jl +++ b/src/pushforwards/polar.jl @@ -1,2 +1,25 @@ -function left_polar_pushforward! end -function right_polar_pushforward! end +function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...) 
+ W, P = WP + ΔW, ΔP = ΔWP + aWdA = adjoint(W) * ΔA + K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA))) + L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W*adjoint(W))*ΔA*inv(P) + ΔW .= W * K̇ + L̇ + ΔP .= aWdA - K̇*P + MatrixAlgebraKit.zero!(ΔA) + return (ΔW, ΔP) +end + +function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) + P, Wᴴ = PWᴴ + ΔP, ΔWᴴ = ΔPWᴴ + dAW = ΔA * adjoint(Wᴴ) + K̇ = sylvester(P, P, -(dAW - adjoint(dAW))) + ImW = (Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ) + @show size(P), size(ΔA), size(ImW), size(Wᴴ) + L̇ = inv(P)*ΔA*ImW + ΔWᴴ .= K̇ * Wᴴ + L̇ + ΔP .= dAW - P * K̇ + MatrixAlgebraKit.zero!(ΔA) + return (ΔWᴴ, ΔP) +end From 891ae98d61d30aeda13bdc444d06ae61d059ddba Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 7 Nov 2025 15:49:56 +0100 Subject: [PATCH 6/6] Some pushforwards progress --- Project.toml | 2 +- .../MatrixAlgebraKitEnzymeExt.jl | 60 +------ src/pushforwards/eig.jl | 20 +-- src/pushforwards/eigh.jl | 2 + src/pushforwards/polar.jl | 4 +- src/pushforwards/qr.jl | 9 +- src/pushforwards/svd.jl | 100 +++++++++--- test/enzyme.jl | 152 +++++++----------- 8 files changed, 153 insertions(+), 196 deletions(-) diff --git a/Project.toml b/Project.toml index ec057141..a5afbc17 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,7 @@ ChainRulesTestUtils = "1" CUDA = "5" GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" -Enzyme = "0.13.77" +Enzyme = "0.13.96" EnzymeTestUtils = "0.2.3" JET = "0.9, 0.10" LinearAlgebra = "1" diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 0c2330a2..a80dce67 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -5,6 +5,7 @@ using MatrixAlgebraKit: diagview, inv_safe, eig_trunc!, eigh_trunc! using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pushforward!, lq_pushforward! 
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pushforward!, lq_null_pushforward! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pushforward!, eigh_pushforward! +using MatrixAlgebraKit: svd_pushforward! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pushforward!, right_polar_pushforward! using Enzyme using Enzyme.EnzymeCore @@ -179,23 +180,7 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ) where {RT} ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing shadow = if EnzymeRules.needs_shadow(config) - U, S, Vᴴ = ret - V = adjoint(Vᴴ) - ∂S = Diagonal(diag(real.(U' * A.dval * V))) - m, n = size(A.val) - F = one(eltype(S)) ./ ((diagview(S).^2)' .- (diagview(S) .^ 2)) - diagview(F) .= zero(eltype(F)) - invSdiag = zeros(eltype(S), length(S.diag)) - for i in 1:length(S.diag) - @inbounds invSdiag[i] = inv(diagview(S)[i]) - end - invS = Diagonal(invSdiag) - ∂U = U * (F .* (U' * A.dval * V * S + S * Vᴴ * A.dval' * U)) + (diagm(ones(eltype(U), m)) - U*U') * A.dval * V * invS - #∂Vᴴ = (FSdS' * Vᴴ) + (invS * U' * A.dval * (diagm(ones(eltype(U), size(V, 2))) - Vᴴ*V)) - ∂V = V * (F .* (S * U' * A.dval * V + Vᴴ * A.dval' * U * S)) + (diagm(ones(eltype(V), n)) - V*Vᴴ) * A.dval' * U * invS - ∂Vᴴ = similar(Vᴴ) - adjoint!(∂Vᴴ, ∂V) - (∂U, ∂S, ∂Vᴴ) + svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval) else nothing end @@ -221,46 +206,7 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig, ) where {RT} ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) 
: nothing shadow = if EnzymeRules.needs_shadow(config) - fatU, fatS, fatVᴴ = ret - ∂Ufat = zeros(eltype(fatU), size(fatU)) - ∂Sfat = zeros(eltype(fatS), size(fatS)) - ∂Vᴴfat = zeros(eltype(fatVᴴ), size(fatVᴴ)) - m, n = size(A.val) - minmn = min(m, n) - #U = view(fatU, :, 1:minmn) - #S = Diagonal(diagview(fatS)) - #Vᴴ = view(fatVᴴ, 1:minmn, :) - U = fatU - S = fatS - Vᴴ = fatVᴴ - V = adjoint(Vᴴ) - ∂S = Diagonal(diag(real.(U' * A.dval * V))) - diagview(∂Sfat) .= diagview(∂S) - m, n = size(A.val) - F = one(eltype(S)) ./ ((diagview(S).^2)' .- (diagview(S) .^ 2)) - diagview(F) .= zero(eltype(F)) - invSdiag = zeros(eltype(S), size(S)) - for ix in diagind(S) - @inbounds invSdiag[ix] = inv(S[ix]) - end - invS = invSdiag - #FSdS = F .* (∂S * S .+ S * ∂S) - ∂U = U * (F .* (U' * A.dval * V * S + S * Vᴴ * A.dval' * U)) + (diagm(ones(eltype(U), m)) - U*U') * A.dval * V * invS - #view(∂Ufat, :, 1:minmn) .= view(∂U, :, :) - ∂Ufat .= ∂U - - - #∂Vᴴ = (FSdS' * Vᴴ) + (invS * U' * A.dval * (diagm(ones(eltype(U), size(V, 2))) - Vᴴ*V)) - ∂V = V * (F .* (S * U' * A.dval * V + Vᴴ * A.dval' * U * S)) + (diagm(ones(eltype(V), n)) - V*Vᴴ) * A.dval' * U * invS - ∂Vᴴ = similar(Vᴴ) - adjoint!(∂Vᴴ, ∂V) - #view(∂Vᴴfat, 1:minmn, :) .= view(∂Vᴴ, :, :) - ∂Vᴴfat .= ∂Vᴴ - #=view(∂Ufat, :, minmn+1:m) .= zero(eltype(fatU)) - view(∂Vᴴfat, minmn+1:n, :) .= zero(eltype(fatVᴴ)) - view(∂Sfat, minmn+1:m, :) .= zero(eltype(fatVᴴ)) - view(∂Sfat, :, minmn+1:n) .= zero(eltype(fatVᴴ))=# - (∂Ufat, ∂Sfat, ∂Vᴴfat) + svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval) else nothing end diff --git a/src/pushforwards/eig.jl b/src/pushforwards/eig.jl index 8a6808e5..19a43cb9 100644 --- a/src/pushforwards/eig.jl +++ b/src/pushforwards/eig.jl @@ -1,12 +1,12 @@ -function eig_pushforward!(dA, A, DV, dDV; kwargs...) +function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...) 
D, V = DV - dD, dV = dDV - ∂K = inv(V) * dA * V - ∂Kdiag = diagview(∂K) - dD.diag .= ∂Kdiag - ∂K ./= transpose(diagview(D)) .- diagview(D) - fill!(∂Kdiag, zero(eltype(D))) - mul!(dV, V, ∂K, 1, 0) - dA .= zero(eltype(dA)) - return dDV + ΔD, ΔV = ΔDV + iVΔAV = inv(V) * ΔA * V + diagview(ΔD) .= diagview(iVΔAV) + F = 1 ./ (transpose(diagview(D)) .- diagview(D)) + fill!(diagview(F), zero(eltype(F))) + K̇ = F .* iVΔAV + mul!(ΔV, V, K̇, 1, 0) + zero!(ΔA) + return ΔDV end diff --git a/src/pushforwards/eigh.jl b/src/pushforwards/eigh.jl index c050a913..69685b16 100644 --- a/src/pushforwards/eigh.jl +++ b/src/pushforwards/eigh.jl @@ -1,4 +1,6 @@ function eigh_pushforward!(dA, A, DV, dDV; kwargs...) + D, V = DV + dD, dV = dDV tmpV = V \ dA ∂K = tmpV * V ∂Kdiag = diag(∂K) diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl index bb9c801e..2001c41a 100644 --- a/src/pushforwards/polar.jl +++ b/src/pushforwards/polar.jl @@ -15,9 +15,7 @@ function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) 
ΔP, ΔWᴴ = ΔPWᴴ dAW = ΔA * adjoint(Wᴴ) K̇ = sylvester(P, P, -(dAW - adjoint(dAW))) - ImW = (Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ) - @show size(P), size(ΔA), size(ImW), size(Wᴴ) - L̇ = inv(P)*ΔA*ImW + L̇ = inv(P)*ΔA*(Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ) ΔWᴴ .= K̇ * Wᴴ + L̇ ΔP .= dAW - P * K̇ MatrixAlgebraKit.zero!(ΔA) diff --git a/src/pushforwards/qr.jl b/src/pushforwards/qr.jl index fbc8e017..26672559 100644 --- a/src/pushforwards/qr.jl +++ b/src/pushforwards/qr.jl @@ -38,23 +38,22 @@ function qr_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pull dR11 .= Rtmp * R11 dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11 dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12) - dQ2 .= -Q1 * (Q1' * dQ2) if size(Q2, 2) > 0 + dQ2 .= -Q1 * (Q1' * Q2) dQ2 .+= Q2 * (Q2' * dQ2) end if m3 > 0 && size(Q, 2) > minmn # only present for qr_full or rank-deficient qr_compact Q′ = view(Q, :, 1:minmn) - println("minmn $minmn m $m") Q3 = view(Q, :, minmn+1:m) #dQ3 .= Q′ * (Q′' * Q3) dQ3 .= Q3 end - #=if !isempty(dR22) + if !isempty(dR22) _, r22 = qr_full(dA2 - dQ1*R12 - Q1*dR12, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true)) dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2)) - end=# + end return (dQ, dR) end -function qr_null_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) end +function qr_null_pushforward!(dA, A, N, dN; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(N), rank_atol::Real=tol, gauge_atol::Real=tol) end diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl index f6cf0c41..f6abc689 100644 --- a/src/pushforwards/svd.jl +++ b/src/pushforwards/svd.jl @@ -1,33 +1,83 @@ -function svd_pushforward!(dA, A, USVᴴ, dUSVᴴ; +function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; tol::Real = default_pullback_gaugetol(USVᴴ[2]), rank_atol::Real = tol, degeneracy_atol::Real = tol, gauge_atol::Real = tol ) - U, S, Vᴴ = USVᴴ - dU, dS, dVᴴ = dUSVᴴ - V = 
adjoint(Vᴴ) - UdAV = U' * dA * V - copyto!(diagview(dS), diag(real.(UdAV))) - m, n = size(A) - F = one(eltype(S)) ./ (diagview(S)' .- diagview(S)) - G = one(eltype(S)) ./ (diagview(S)' .+ diagview(S)) + U, Smat, Vᴴ = USVᴴ + m, n = size(U, 1), size(Vᴴ, 2) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) + minmn = min(m, n) + S = diagview(Smat) + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ + r = searchsortedlast(S, rank_atol; rev = true) # rank + + vΔU = view(ΔU, :, 1:r) + vΔS = view(ΔS, 1:r, 1:r) + vΔVᴴ = view(ΔVᴴ, 1:r, :) + + vU = view(U, :, 1:r) + vS = view(S, 1:r) + vSmat = view(Smat, 1:r, 1:r) + vVᴴ = view(Vᴴ, 1:r, :) + + # compact region + vV = adjoint(vVᴴ) + UΔAV = vU' * ΔA * vV + copyto!(diagview(vΔS), diag(real.(UΔAV))) + F = one(eltype(S)) ./ (transpose(vS) .- vS) + G = one(eltype(S)) ./ (transpose(vS) .+ vS) diagview(F) .= zero(eltype(F)) - invSdiag = zeros(eltype(S), length(diagview(S))) - for i in 1:length(diagview(S)) - @inbounds invSdiag[i] = inv(diagview(S)[i]) + hUΔAV = F .* (UΔAV + UΔAV') ./ 2 + aUΔAV = G .* (UΔAV - UΔAV') ./ 2 + K̇ = hUΔAV + aUΔAV + Ṁ = hUΔAV - aUΔAV + + # check gauge condition + @assert isantihermitian(K̇) + @assert isantihermitian(Ṁ) + K̇diag = diagview(K̇) + for i in 1:length(K̇diag) + @assert K̇diag[i] ≈ (im/2) * imag(diagview(UΔAV)[i])/S[i] end - invS = Diagonal(invSdiag) - #∂U = U * (F .* (U' * dA * V * S + S * Vᴴ * dA' * U)) + (LinearAlgebra.diagm(ones(eltype(U), m)) - U*U') * dA * V * invS - #∂V = V * (F .* (S * U' * dA * V + Vᴴ * dA' * U * S)) + (LinearAlgebra.diagm(ones(eltype(V), n)) - V*Vᴴ) * dA' * U * invS - hUdAV = F .* project_hermitian(UdAV) - aUdAV = G .* project_antihermitian(UdAV) - ∂U = U * (hUdAV + aUdAV) - ∂U += (LinearAlgebra.diagm(ones(eltype(U), m)) - U*U') * dA * V * invS - ∂V = V * (hUdAV - aUdAV) - ∂V += (LinearAlgebra.diagm(ones(eltype(U), n)) - V*V') * dA' * U * invS - copyto!(dU, ∂U) - adjoint!(dVᴴ, ∂V) - dA .= zero(eltype(A)) - return (dU, dS, dVᴴ) + + ∂U = 
vU * K̇ + ∂V = vV * Ṁ + # full component + if size(U, 2) > minmn && size(Vᴴ, 1) > minmn + Uperp = view(U, :, minmn+1:m) + Vᴴperp = view(Vᴴ, minmn+1:n, :) + + aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp) + + UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2))) + fill!(UÃÃV, 0) + view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV + view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV' + rhs = vcat( adjoint(Uperp, ΔA, V), Vᴴperp * ΔA' * U) + superKM = -sylvester(UÃÃV, Smat, rhs) + K̇perp = view(superKM, 1:size(aUAV, 2)) + Ṁperp = view(superKM, size(aUAV, 2)+1:size(aUAV, 1)+size(aUAV, 2)) + ∂U .+= Uperp * K̇perp + ∂V .+= Vperp * Ṁperp + else + ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU*vU') + ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV*vVᴴ) + upper = ImUU * ΔA * vV + lower = ImVV * ΔA' * vU + rhs = vcat(upper, lower) + + Ã = ImUU * A * ImVV + ÃÃ = similar(A, (m + n, m + n)) + fill!(ÃÃ, 0) + view(ÃÃ, (1:m), m .+ (1:n)) .= Ã + view(ÃÃ, m .+ (1:n), 1:m ) .= Ã' + + superLN = -sylvester(ÃÃ, vSmat, rhs) + ∂U += view(superLN, 1:size(upper, 1), :) + ∂V += view(superLN, size(upper, 1)+1:size(upper,1)+size(lower,1), :) + end + copyto!(vΔU, ∂U) + adjoint!(vΔVᴴ, ∂V) + return (ΔU, ΔS, ΔVᴴ) end diff --git a/test/enzyme.jl b/test/enzyme.jl index b9b16c45..fc4435d1 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -2,7 +2,7 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using ChainRulesCore +using LinearAlgebra using Enzyme, EnzymeTestUtils using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! @@ -12,7 +12,7 @@ is_ci = get(ENV, "CI", "false") == "true" #ETs = is_ci ? 
(Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631 ETs = (Float64, ComplexF64) include("ad_utils.jl") -#= + @timedtestset "QR AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 @@ -28,12 +28,12 @@ include("ad_utils.jl") ΔQ = randn(rng, T, m, minmn) ΔR = randn(rng, T, minmn, n) fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - test_reverse(qr_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR), fdm=fdm) + test_forward(qr_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), fdm=fdm) end - @testset "qr_null" begin + #=@testset "qr_null" begin Q, R = qr_compact(A, alg) ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) - test_reverse(qr_null, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=ΔN) + test_forward(qr_null, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) end @testset "qr_full" begin Q, R = qr_full(A, alg) @@ -43,9 +43,9 @@ include("ad_utils.jl") mul!(ΔQ2, Q1, Q1' * ΔQ2) ΔR = randn(rng, T, m, n) fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - test_reverse(qr_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR), fdm=fdm) - end - @testset "qr_compact - rank-deficient A" begin + test_forward(qr_full, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), fdm=fdm) + end=# + #=@testset "qr_compact - rank-deficient A" begin r = minmn - 5 Ard = randn(rng, T, m, r) * randn(rng, T, r, n) Q, R = qr_compact(Ard, alg) @@ -56,13 +56,13 @@ include("ad_utils.jl") ΔQ2 .= 0 ΔR = randn(rng, T, minmn, n) view(ΔR, (r + 1):minmn, :) .= 0 - test_reverse(qr_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔQ, ΔR)) - end + test_forward(qr_compact, RT, (Ard, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) + end=# end end end end - +#= @timedtestset "LQ AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 @@ -114,18 +114,7 @@ end end end =# -function MatrixAlgebraKit.eig_trunc!(A, DV, ϵ::Vector{T}, alg::MatrixAlgebraKit.TruncatedAlgorithm) where {T} - D, V = eig_full!(A, DV, alg.alg) - DVtrunc, ind = MatrixAlgebraKit.truncate(eig_trunc!, (D, V), alg.trunc) - ϵ[1] = MatrixAlgebraKit.truncation_error!(diagview(D), ind) - return DVtrunc..., ϵ -end -function dummy_eig_trunc(A, ϵ::Vector{T}, alg::TruncatedAlgorithm) where {T} - Ac = MatrixAlgebraKit.copy_input(MatrixAlgebraKit.eig_trunc, A) - DV = MatrixAlgebraKit.initialize_output(eig_trunc!, A, alg) - Dtrunc, Vtrunc, ϵ = MatrixAlgebraKit.eig_trunc!(Ac, DV, ϵ, alg) - return Dtrunc, Vtrunc, ϵ -end + @timedtestset "EIG AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) @@ -139,19 +128,14 @@ end ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol=atol) ΔD = randn(rng, complex(T), m, m) ΔD2 = Diagonal(randn(rng, complex(T), m)) - @testset for alg in (LAPACK_Simple(), LAPACK_Expert()) + @testset for alg in (LAPACK_Simple(), + #LAPACK_Expert(), + ) @testset "forward: RT 
$RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - # make A hermitian - Ah = (A + A')/2 - test_forward((Ah; kwargs...)->eig_full(A; kwargs...)[1], RT, Duplicated(A, ΔA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) - test_forward((Ah; kwargs...)->eig_full(A; kwargs...)[2], RT, Duplicated(A, ΔA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) - #test_forward(eig_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + test_forward(eig_full, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) + test_forward(eig_vals, RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) end #=@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - test_reverse(eig_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) - test_reverse(eig_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) - end - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) for r in 1:4:m truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) ind = MatrixAlgebraKit.findtruncated(diagview(D), truncalg.trunc) @@ -159,8 +143,7 @@ end Vtrunc = V[:, ind] ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) ΔVtrunc = ΔV[:, ind] - ϵ = [zero(real(T))] - test_reverse(dummy_eig_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))])) + test_forward(eig_trunc, RT, (A, TA), (truncalg, Const); atol=atol, rtol=rtol) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -171,8 +154,7 @@ end Vtrunc = V[:, ind] ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) ΔVtrunc = ΔV[:, ind] - ϵ = [zero(real(T))] - test_reverse(dummy_eig_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))])) + 
test_forward(eig_trunc, RT, (A, TA), (truncalg, Const); atol=atol, rtol=rtol) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -190,20 +172,6 @@ function copy_eigh_vals(A; kwargs...) eigh_vals(A; kwargs...) end -function MatrixAlgebraKit.eigh_trunc!(A, DV, ϵ::Vector{T}, alg::MatrixAlgebraKit.TruncatedAlgorithm) where {T} - D, V = eigh_full!(A, DV, alg.alg) - DVtrunc, ind = MatrixAlgebraKit.truncate(eigh_trunc!, (D, V), alg.trunc) - ϵ[1] = MatrixAlgebraKit.truncation_error!(diagview(D), ind) - return DVtrunc..., ϵ -end -function dummy_eigh_trunc(A, ϵ::Vector{T}, alg::TruncatedAlgorithm) where {T} - A = (A + A')/2 - Ac = MatrixAlgebraKit.copy_input(MatrixAlgebraKit.eigh_trunc, A) - DV = MatrixAlgebraKit.initialize_output(eigh_trunc!, A, alg) - Dtrunc, Vtrunc, ϵ = MatrixAlgebraKit.eigh_trunc!(Ac, DV, ϵ, alg) - return Dtrunc, Vtrunc, ϵ -end - @timedtestset "EIGH AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 @@ -217,20 +185,16 @@ end ΔD = randn(rng, real(T), m, m) ΔD2 = Diagonal(randn(rng, real(T), m)) @testset for alg in (LAPACK_QRIteration(), - LAPACK_DivideAndConquer(), - LAPACK_Bisection(), - LAPACK_MultipleRelativelyRobustRepresentations(), + #LAPACK_DivideAndConquer(), + #LAPACK_Bisection(), + #LAPACK_MultipleRelativelyRobustRepresentations(), ) - @testset "forward: RT $RT, TA $TA" for RT in (Const, Duplicated,), TA in (Const, Duplicated,) + @testset "forward: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) test_forward((A; kwargs...)->copy_eigh_full(A; kwargs...)[1], RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) test_forward((A; kwargs...)->copy_eigh_full(A; kwargs...)[2], RT, (A, TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) test_forward(copy_eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol) end #=@testset "reverse: RT 
$RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - test_reverse(copy_eigh_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV))) - test_reverse(copy_eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag)) - end - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) for r in 1:4:m Ddiag = diagview(D) truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) @@ -239,8 +203,7 @@ end Vtrunc = V[:, ind] ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) ΔVtrunc = ΔV[:, ind] - ϵ = [zero(real(T))] - test_reverse(dummy_eigh_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))])) + test_forward(eigh_trunc, RT, (A, TA), (truncalg, Const); atol=atol, rtol=rtol) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -252,8 +215,7 @@ end Vtrunc = V[:, ind] ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) ΔVtrunc = ΔV[:, ind] - ϵ = [zero(real(T))] - test_reverse(dummy_eigh_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))])) + test_forward(eigh_trunc, RT, (A, TA), (truncalg, Const); atol=atol, rtol=rtol) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -261,19 +223,6 @@ end end end -function MatrixAlgebraKit.svd_trunc!(A, USVᴴ, ϵ::Vector{T}, alg::MatrixAlgebraKit.TruncatedAlgorithm) where {T} - U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) - USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - ϵ[1] = MatrixAlgebraKit.truncation_error!(diagview(S), ind) - return 
USVᴴtrunc..., ϵ -end -function dummy_svd_trunc(A, ϵ::Vector{T}, alg::TruncatedAlgorithm) where {T} - Ac = MatrixAlgebraKit.copy_input(MatrixAlgebraKit.svd_trunc, A) - USVᴴ = MatrixAlgebraKit.initialize_output(svd_trunc!, A, alg) - Utrunc, Strunc, Vᴴtrunc, ϵ = MatrixAlgebraKit.svd_trunc!(Ac, USVᴴ, ϵ, alg) - return Utrunc, Strunc, Vᴴtrunc, ϵ -end - @timedtestset "SVD AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 @@ -282,15 +231,29 @@ end A = randn(rng, T, m, n) minmn = min(m, n) @testset for alg in (LAPACK_QRIteration(), - LAPACK_DivideAndConquer(), + #LAPACK_DivideAndConquer(), ) @testset "forward: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + function mak_pinv_compact(A, alg) + U, S, Vᴴ = svd_compact(A, alg) + return U * inv(S) * Vᴴ + end + function mak_pinv_full(A, alg) + U, S, Vᴴ = svd_full(A, alg) + return U * inv(S) * Vᴴ + end @testset "svd_compact" begin A = randn(rng, T, m, m) - Ah = (A + A')/2 fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - test_forward(svd_compact, RT, (Ah, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), fdm=fdm) + test_forward(mak_pinv_compact, RT, (A, TA), (alg, Const); atol=atol, rtol=rtol, fdm=fdm) + test_forward((A, alg)->svd_compact(A, alg)[2], RT, (A, TA), (alg, Const); atol=atol, rtol=rtol, fdm=fdm) end + #=@testset "svd_full" begin # horrible Enzyme error here + A = randn(rng, T, m, m) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_forward(mak_pinv_full, RT, (A, TA), (alg, Const); atol=atol, rtol=rtol, fdm=fdm) + test_forward((A, alg)->svd_full(A, alg)[2], RT, (A, TA), (alg, Const); atol=atol, rtol=rtol, fdm=fdm) + end=# end #=@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) @testset "svd_compact" begin @@ -380,25 +343,24 @@ end A = randn(rng, T, m, n) @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) @testset "left_orth" begin - @testset for kind in (:polar, :qr) - n > m && kind == :polar && continue - test_reverse(left_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(kind=kind,)) + @testset for alg in (:polar, :qr,) + n > m && alg == :polar && continue + test_forward(left_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) end end - @testset "right_orth" begin - @testset for kind in (:polar, :lq) - n < m && kind == :polar && continue - test_reverse(right_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(kind=kind,)) + #=@testset "right_orth" begin + @testset for alg in (:polar, :lq) + n < m && alg == :polar && continue + test_reverse(right_orth, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,)) end - end - @testset "left_null" begin - ΔN = left_orth(A; kind=:qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - test_reverse(left_null, RT, (A, TA); fkwargs=(; kind=:qr), output_tangent=ΔN, atol=atol, rtol=rtol) - end - @testset "right_null" begin - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lq)[2] - test_reverse(right_null, RT, (A, TA); fkwargs=(; kind=:lq), output_tangent=ΔNᴴ, atol=atol, rtol=rtol) - end + end=# + #=@testset "left_null" begin + test_forward(left_null, RT, (A, TA); fkwargs=(; alg=:qr), atol=atol, rtol=rtol) + end=# + #=@testset "right_null" begin + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg=:lq)[2] + test_reverse(right_null, RT, 
(A, TA); fkwargs=(; alg=:lq), output_tangent=ΔNᴴ, atol=atol, rtol=rtol) + end=# end end end