-
Notifications
You must be signed in to change notification settings - Fork 5
Reverse rules for Enzyme #86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov Report❌ Patch coverage is
... and 31 files with indirect coverage changes 🚀 New features to boost your workflow:
|
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
index e18cc41..8ef0b4e 100644
--- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
+++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
@@ -198,8 +198,8 @@ function EnzymeRules.augmented_primal(
svd_compact!(A.val, USVᴴ.val, alg.val.alg)
cache_USVᴴ = copy.(USVᴴ.val)
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc)
- ϵ.val = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
- primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
+ ϵ.val = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
+ primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴ, Const)
dU, dS, dVᴴ = USVᴴ.dval
# This creates new output shadow matrices, we do this slicing
@@ -207,8 +207,8 @@ function EnzymeRules.augmented_primal(
# These new shadow matrices are "filled in" with the accumulated
# results from earlier in reverse-mode AD after this function exits
# and before `reverse` is called.
- dStrunc = Diagonal(diagview(dS)[ind])
- dUtrunc = dU[:, ind]
+ dStrunc = Diagonal(diagview(dS)[ind])
+ dUtrunc = dU[:, ind]
dVᴴtrunc = dVᴴ[ind, :]
(dUtrunc, dStrunc, dVᴴtrunc)
else
@@ -227,9 +227,9 @@ function EnzymeRules.reverse(
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
) where {RT}
cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache
- U, S, Vᴴ = cache_USVᴴ
+ U, S, Vᴴ = cache_USVᴴ
dU, dS, dVᴴ = shadow_USVᴴ
- Aval = isnothing(cache_A) ? A.val : cache_A
+ Aval = isnothing(cache_A) ? A.val : cache_A
if !isa(A, Const) && !isa(USVᴴ, Const)
svd_pullback!(A.dval, Aval, (U, S, Vᴴ), shadow_USVᴴ, ind)
end
@@ -246,14 +246,14 @@ function EnzymeRules.augmented_primal(
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
)
# form cache if needed
- cache_A = copy(A.val)
+ cache_A = copy(A.val)
U, S, Vᴴ, ϵ = svd_trunc(A.val, USVᴴ.val, alg.val.alg)
- primal = EnzymeRules.needs_primal(config) ? (U, S, Vᴴ, ϵ) : nothing
- dU = zero(U)
- dS = zero(S)
+ primal = EnzymeRules.needs_primal(config) ? (U, S, Vᴴ, ϵ) : nothing
+ dU = zero(U)
+ dS = zero(S)
dVᴴ = zero(Vᴴ)
- dϵ = zero(ϵ)
- shadow = EnzymeRules.needs_shadow(config) ? (dU, dS, dVᴴ, dϵ) : nothing
+ dϵ = zero(ϵ)
+ shadow = EnzymeRules.needs_shadow(config) ? (dU, dS, dVᴴ, dϵ) : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, (U, S, Vᴴ), (dU, dS, dVᴴ)))
end
function EnzymeRules.reverse(
@@ -265,9 +265,9 @@ function EnzymeRules.reverse(
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
)
cache_A, cache_USVᴴ, shadow_USVᴴ = cache
- U, S, Vᴴ = cache_USVᴴ
+ U, S, Vᴴ = cache_USVᴴ
dU, dS, dVᴴ = shadow_USVᴴ
- Aval = isnothing(cache_A) ? A.val : cache_A
+ Aval = isnothing(cache_A) ? A.val : cache_A
if !isa(A, Const) && !isa(USVᴴ, Const)
svd_trunc_pullback!(A.dval, Aval, (U, S, Vᴴ), shadow_USVᴴ, ind)
end
diff --git a/test/enzyme.jl b/test/enzyme.jl
index 208f456..7c257bd 100644
--- a/test/enzyme.jl
+++ b/test/enzyme.jl
@@ -397,7 +397,7 @@ end
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
# broken due to Enzyme -- copying in gaugefix????
test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
- test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
+ test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act = RT)
end
U, S, Vᴴ = svd_compact(A)
ΔU = randn(rng, T, m, minmn)
@@ -416,7 +416,7 @@ end
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
# broken due to Enzyme
test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
- test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
+ test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act = RT)
end
end
end |
4408849 to
ec8354b
Compare
| func::Const{typeof($f)}, | ||
| ::Type{RT}, | ||
| A::Annotation{<:AbstractMatrix}, | ||
| arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove AbstractMatrix restrictions here and elsewhere
| ) where {RT} | ||
| cache_arg = nothing | ||
| # form cache if needed | ||
| cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(arg) <: Const)) ? copy(A.val) : nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this safe for the ! methods? Those should be tested directly
7478c90 to
9091c56
Compare
No description provided.