Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
MatrixAlgebraKitCUDAExt = "CUDA"
MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra"
MatrixAlgebraKitGenericSchurExt = "GenericSchur"
MatrixAlgebraKitEnzymeExt = "Enzyme"
MatrixAlgebraKitMooncakeExt = "Mooncake"

[compat]
AMDGPU = "2"
Expand All @@ -28,8 +32,11 @@ ChainRulesTestUtils = "1"
CUDA = "5"
GenericLinearAlgebra = "0.3.19"
GenericSchur = "0.5.6"
Enzyme = "0.13.96"
EnzymeTestUtils = "0.2.3"
JET = "0.9, 0.10"
LinearAlgebra = "1"
Mooncake = "0.4.167"
SafeTestsets = "0.1"
StableRNGs = "1"
Test = "1"
Expand All @@ -40,12 +47,14 @@ julia = "1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Enzyme", "EnzymeTestUtils", "Mooncake"]
669 changes: 669 additions & 0 deletions ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Large diffs are not rendered by default.

300 changes: 300 additions & 0 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
module MatrixAlgebraKitMooncakeExt

using Mooncake
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
using MatrixAlgebraKit
using MatrixAlgebraKit: inv_safe, diagview
using MatrixAlgebraKit: svd_pushforward!
using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pushforward!, lq_pushforward!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pushforward!, lq_null_pushforward!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pushforward!, eigh_pushforward!
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pushforward!, right_polar_pushforward!
using LinearAlgebra

# two-argument factorizations like LQ, QR, EIG
for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
(qr_compact!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
(lq_full!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
(lq_compact!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
(eig_full!, eig_pullback!, eig_pushforward!, :deig_adjoint),
(eigh_full!, eigh_pullback!, eigh_pushforward!, :deigh_adjoint),
(left_polar!, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint),
(right_polar!, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint),
)

@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
A, dA = arrayify(A_dA)
dA .= zero(eltype(A))
args = Mooncake.primal(args_dargs)
dargs = Mooncake.tangent(args_dargs)
arg1, darg1 = arrayify(args[1], dargs[1])
arg2, darg2 = arrayify(args[2], dargs[2])
function $adj(::Mooncake.NoRData)
dA = $pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...)
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
darg1 .= zero(eltype(arg1))
darg2 .= zero(eltype(arg2))
return Mooncake.CoDual(args, dargs), $adj
end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
A, dA = arrayify(A_dA)
args = Mooncake.primal(args_dargs)
args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
dargs = Mooncake.tangent(args_dargs)
arg1, darg1 = arrayify(args[1], dargs[1])
arg2, darg2 = arrayify(args[2], dargs[2])
darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2))
dA .= zero(eltype(A))
return Mooncake.Dual(args, dargs)
end
end
end

for (f, f_full, pb, pf, adj) in ((qr_null!, qr_full, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint),
(lq_null!, lq_full, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint),
)
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, arg_darg::CoDual{<:AbstractMatrix}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
A, dA = arrayify(A_dA)
Ac = MatrixAlgebraKit.copy_input($f_full, A)
arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
function $adj(::Mooncake.NoRData)
dA .= zero(eltype(A))
$pb(dA, A, arg, darg; kwargs...)
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
return arg_darg, $adj
end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(f_df::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, arg_darg::Dual{<:AbstractMatrix}, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
A, dA = arrayify(A_dA)
Ac = MatrixAlgebraKit.copy_input($f_full, A)
arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
$pf(dA, A, arg, darg; kwargs...)
dA .= zero(dA)
return arg_darg
end
end
end

@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...)
# compute primal
D_ = Mooncake.primal(D_dD)
dD_ = Mooncake.tangent(D_dD)
A_ = Mooncake.primal(A_dA)
dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
nD, V = eig_full(A, alg_dalg.primal; kwargs...)

# update tangent
tmp = V \ dA
dD .= diagview(tmp * V)
dA .= zero(eltype(dA))
return Mooncake.Dual(nD.diag, dD_)
end

function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
# compute primal
D_ = Mooncake.primal(D_dD)
dD_ = Mooncake.tangent(D_dD)
A_ = Mooncake.primal(A_dA)
dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
dA .= zero(eltype(dA))
# update primal
DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
V = DV[2]
dD .= zero(eltype(D))
function deig_vals_adjoint(::Mooncake.NoRData)
PΔV = V' \ Diagonal(dD)
if eltype(dA) <: Real
ΔAc = PΔV * V'
dA .+= real.(ΔAc)
else
mul!(dA, PΔV, V', 1, 0)
end
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
return Mooncake.CoDual(DV[1].diag, dD_), deig_vals_adjoint
end
#=
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_full!), AbstractMatrix, Tuple{<:Diagonal, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_full!)}, A_dA::CoDual{<:AbstractMatrix}, DV_dDV::CoDual{<:Tuple{<:Diagonal, <:AbstractMatrix}}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
A, dA = arrayify(A_dA)
dA .= zero(eltype(A))
DV = Mooncake.primal(DV_dDV)
dDV = Mooncake.tangent(DV_dDV)
D, dD = arrayify(DV[1], dDV[1])
V, dV = arrayify(DV[2], dDV[2])
function deigh_adjoint(::Mooncake.NoRData)
dA = MatrixAlgebraKit.eigh_pullback!(dA, A, (D, V), (dD, dV); kwargs...)
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
DV = eigh_full!(A, DV, Mooncake.primal(alg_dalg); kwargs...)
return Mooncake.CoDual(DV, dDV), deigh_adjoint
end

@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eigh_full!), AbstractMatrix, Tuple{<:Diagonal, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{typeof(MatrixAlgebraKit.eigh_full!)}, A_dA::Dual, DV_dDV::Dual, alg_dalg::Dual; kwargs...)
A, dA = arrayify(A_dA)
DV = Mooncake.primal(DV_dDV)
dDV = Mooncake.tangent(DV_dDV)
D, dD = arrayify(DV[1], dDV[1])
V, dV = arrayify(DV[2], dDV[2])
(D, V) = eigh_full!(A, DV, Mooncake.primal(alg_dalg); kwargs...)
(dD, dV) = eigh_pushforward!(dA, A, (D, V), (dD, dV); kwargs...)
return Mooncake.Dual(DV, dDV)
end
=#

@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...)
# compute primal
D_ = Mooncake.primal(D_dD)
dD_ = Mooncake.tangent(D_dD)
A_ = Mooncake.primal(A_dA)
dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
nD, V = eigh_full(A, alg_dalg.primal; kwargs...)
# update tangent
tmp = inv(V) * dA * V
dD .= real.(diagview(tmp))
D .= nD.diag
dA .= zero(eltype(dA))
return D_dD
end

function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
# compute primal
D_ = Mooncake.primal(D_dD)
dD_ = Mooncake.tangent(D_dD)
A_ = Mooncake.primal(A_dA)
dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
function deigh_vals_adjoint(::Mooncake.NoRData)
mul!(dA, DV[2] * Diagonal(real(dD)), DV[2]', 1, 0)
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
return Mooncake.CoDual(DV[1].diag, dD_), deigh_vals_adjoint
end


for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal))
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...)
A, dA = arrayify(A_dA)
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
function dsvd_adjoint(::Mooncake.NoRData)
dA .= zero(eltype(A))
minmn = min(size(A)...)
if size(U, 2) == size(Vᴴ, 1) == minmn # compact
dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
else # full
vU = view(U, :, 1:minmn)
vS = Diagonal(diagview(S)[1:minmn])
vVᴴ = view(Vᴴ, 1:minmn, :)
vdU = view(dU, :, 1:minmn)
vdS = view(dS, 1:minmn, 1:minmn)
vdVᴴ = view(dVᴴ, 1:minmn, :)
dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
end
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
dU .= zero(dU)
dS .= zero(dS)
dVᴴ .= zero(dVᴴ)
return Mooncake.CoDual(USVᴴ, dUSVᴴ), dsvd_adjoint
end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof($f)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual; kwargs...)
# compute primal
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
A_ = Mooncake.primal(A_dA)
dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
$f(A, USVᴴ, alg_dalg.primal; kwargs...)

# update tangents
U_, S_, Vᴴ_ = USVᴴ
dU_, dS_, dVᴴ_ = dUSVᴴ
U, dU = arrayify(U_, dU_)
S, dS = arrayify(S_, dS_)
Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_)
minmn = min(size(A)...)
if ($f == svd_compact!) # compact
svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...)
else # full
vU = view(U, :, 1:minmn)
vS = view(S, 1:minmn, 1:minmn)
vVᴴ = view(Vᴴ, 1:minmn, :)
vdU = view(dU, :, 1:minmn)
vdS = view(dS, 1:minmn, 1:minmn)
vdVᴴ = view(dVᴴ, 1:minmn, :)
svd_pushforward!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ); kwargs...)
end
return USVᴴ_dUSVᴴ
end
end
end

@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual; kwargs...)
# compute primal
S_ = Mooncake.primal(S_dS)
dS_ = Mooncake.tangent(S_dS)
A_ = Mooncake.primal(A_dA)
dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)

# update tangent
S, dS = arrayify(S_, dS_)
copyto!(dS, diag(real.(Vᴴ * dA' * U)))
copyto!(S, diagview(nS))
dA .= zero(eltype(dA))
return Mooncake.Dual(nS.diag, dS)
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...)
# compute primal
S_ = Mooncake.primal(S_dS)
dS_ = Mooncake.tangent(S_dS)
A_ = Mooncake.primal(A_dA)
dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
S, dS = arrayify(S_, dS_)
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
S .= diagview(nS)
dS .= zero(eltype(S))
function dsvd_vals_adjoint(::Mooncake.NoRData)
dA .= U * Diagonal(dS) * Vᴴ
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
return S_dS, dsvd_vals_adjoint
end

end
7 changes: 7 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,11 @@ include("pullbacks/eigh.jl")
include("pullbacks/svd.jl")
include("pullbacks/polar.jl")

include("pushforwards/qr.jl")
include("pushforwards/lq.jl")
include("pushforwards/eig.jl")
include("pushforwards/eigh.jl")
include("pushforwards/polar.jl")
include("pushforwards/svd.jl")

end
2 changes: 1 addition & 1 deletion src/common/view.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# diagind: provided by LinearAlgebra.jl
diagview(D::Diagonal) = D.diag
diagview(D::Diagonal) = D.diag
diagview(D::AbstractMatrix) = view(D, diagind(D))

# triangularind
Expand Down
2 changes: 1 addition & 1 deletion src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real =
end

function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm)
check_hermitian(A, alg)
#check_hermitian(A, alg)
D, V = DV
m = size(A, 1)
@assert D isa Diagonal && V isa AbstractMatrix
Expand Down
3 changes: 2 additions & 1 deletion src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ function eig_pullback!(
Δgauge < gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
VᴴΔV ./= conj.(transpose(D) .- D)
diagview(VᴴΔV) .= zero(eltype(VᴴΔV))

if !iszerotangent(ΔDmat)
ΔDvec = diagview(ΔDmat)
Expand Down
Loading
Loading