Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,32 @@
const have_not_warned_vjp = Ref(true)
const STACKTRACE_WITH_VJPWARN = Ref(false)

function adfunc(out, u, _p, t, repack)
f(out, u, repack(_p), t)
nothing
end

function inplace_vjp(prob, u0, p, verbose, repack)
du = zero(u0)

ez = try
f = unwrapped_f(prob.f)

function adfunc(out, u, _p, t)
f(out, u, repack(_p), t)
nothing
end
Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(du, copy(u0)),
Enzyme.Duplicated(copy(u0), zero(u0)), Enzyme.Duplicated(copy(p), zero(p)), Enzyme.Const(prob.tspan[1]))
Enzyme.Duplicated(copy(u0), zero(u0)), Enzyme.Duplicated(copy(p), zero(p)), Enzyme.Const(prob.tspan[1]), Enzyme.Const(repack))
true
catch e
false
end
if ez
return EnzymeVJP()
end

erz = try
f = unwrapped_f(prob.f)

Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Reverse), adfunc, Enzyme.Duplicated(du, copy(u0)),
Enzyme.Duplicated(copy(u0), zero(u0)), Enzyme.Duplicated(copy(p), zero(p)), Enzyme.Const(prob.tspan[1]), Enzyme.Const(repack))
true
catch e
if verbose && have_not_warned_vjp[]
Expand All @@ -28,8 +42,8 @@ function inplace_vjp(prob, u0, p, verbose, repack)
end
false
end
if ez
return EnzymeVJP()
if erz
return EnzymeVJP(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))
end

# Determine if we can compile ReverseDiff
Expand Down
30 changes: 16 additions & 14 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,16 @@ function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy,
return dy, dλ, dgrad
end

function gclosure1(f, du, u, p, t)
Base.copyto!(du, f(u, p, t))
nothing
end

function gclosure2(du, u, p, t, W)
Base.copyto!(du, f(u, p, t, W))
nothing
end

function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy,
W) where {TS <: SensitivityFunction}
(; sensealg) = S
Expand Down Expand Up @@ -732,13 +742,13 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
end

if W === nothing
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6),
Enzyme.autodiff(isautojacvec.mode, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6),
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
dup,
Enzyme.Const(t))
else
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6),
Enzyme.autodiff(isautojacvec.mode, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6),
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
dup,
Expand All @@ -750,22 +760,14 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
dy !== nothing && recursive_copyto!(dy, tmp3)
else
if W === nothing
function g(du, u, p, t)
du .= f(u, p, t)
nothing
end
_tmp6 = Enzyme.make_zero(g)
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(g, _tmp6),
_tmp6 = Enzyme.make_zero(f)
Enzyme.autodiff(isautojacvec.mode, Enzyme.Const(gclosure1), Enzyme.Duplicated(f, _tmp6),
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
dup, Enzyme.Const(t))
else
function g(du, u, p, t, W)
du .= f(u, p, t, W)
nothing
end
_tmp6 = Enzyme.make_zero(g)
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(g, _tmp6),
_tmp6 = Enzyme.make_zero(f)
Enzyme.autodiff(isautojacvec.mode, Enzyme.Const(gclosure2), Enzyme.Duplicated(f, _tmp6),
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
dup, Enzyme.Const(t), Enzyme.Const(W))
Expand Down
15 changes: 8 additions & 7 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,11 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)
sensealg, dgdp_cache, dgdp)
end

function g(f, du, u, p, t)
Base.copyto!(du, f(u, p, t))
nothing
end

# out = λ df(u, p, t)/dp at u=y, p=p, t=t
function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
(; pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol) = S
Expand Down Expand Up @@ -500,17 +505,13 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
Enzyme.remake_zero!(tmp6)

Enzyme.autodiff(
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
sensealg.autojacvec.mode, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t))
else
function g(du, u, p, t)
du .= f(u, p, t)
nothing
end
tmp6 = Enzyme.make_zero(g)
tmp6 = Enzyme.make_zero(f)
Enzyme.autodiff(
Enzyme.Reverse, Enzyme.Duplicated(g, tmp6), Enzyme.Const,
sensealg.autojacvec.mode, Enzyme.Const(gclosure3), Enzyme.Duplicated(f, tmp6), Enzyme.Const,
Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t))
end
Expand Down
15 changes: 8 additions & 7 deletions src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,11 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
AdjointSensitivityIntegrand(sol, adj_sol, p, y, λ, pf, f_cache, pJ, paramjac_config,
sensealg, dgdp_cache, dgdp)
end

function gclosure4(f, du, u, p, t)
Base.copyto!(du, f(u, p, t))
nothing
end

# out = λ df(u, p, t)/dp at u=y, p=p, t=t
function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
Expand Down Expand Up @@ -295,17 +300,13 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
if SciMLBase.isinplace(sol.prob.f)
Enzyme.remake_zero!(tmp6)
Enzyme.autodiff(
Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), tmp6), Enzyme.Const,
sensealg.autojacvec.mode, Enzyme.Duplicated(SciMLBase.Void(f), tmp6), Enzyme.Const,
Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Const(y), dup, Enzyme.Const(t))
else
function g(du, u, p, t)
du .= f(u, p, t)
nothing
end
tmp6 = Enzyme.make_zero(g)
tmp6 = Enzyme.make_zero(f)
Enzyme.autodiff(
Enzyme.set_runtime_activity(Enzyme.Reverse), Enzyme.Duplicated(g, tmp6), Enzyme.Const,
sensealg.autojacvec.mode, Enzyme.Const(gclosure4), Enzyme.Duplicated(f, tmp6), Enzyme.Const,
Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Const(y), dup, Enzyme.Const(t))
end
Expand Down
9 changes: 6 additions & 3 deletions src/sensitivity_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1306,7 +1306,7 @@ like BLAS/LAPACK are used) and this will be the most efficient adjoint implement
## Constructor

```julia
EnzymeVJP(; chunksize = 0)
EnzymeVJP(; chunksize = 0, mode = EnzymeCore.Reverse)
```

## Keyword Arguments
Expand All @@ -1317,12 +1317,15 @@ EnzymeVJP(; chunksize = 0)
should be set to the maximum chunksize that can occur during an integration to preallocate
the `DualCaches` for PreallocationTools.jl. It defaults to 0, using `ForwardDiff.pickchunksize`
but could be decreased if this value is known to be lower to conserve memory.
- `mode`: the parameterized Enzyme mode, default set to EnzymeCore.Reverse. Alternatively one
may want to pass Enzyme.set_runtime_activity(Enzyme.Reverse)
"""
struct EnzymeVJP <: VJPChoice
struct EnzymeVJP{Mode<:Enzyme.ReverseMode} <: VJPChoice
chunksize::Int
mode::Mode
end

EnzymeVJP(; chunksize = 0) = EnzymeVJP(chunksize)
EnzymeVJP(; chunksize = 0, mode = Enzyme.Reverse) = EnzymeVJP(chunksize, mode)

"""
```julia
Expand Down
Loading