Skip to content

Commit 3bae173

Browse files
committed
fix adjoints
1 parent 3978e7e commit 3bae173

File tree

10 files changed

+44
-26
lines changed

10 files changed

+44
-26
lines changed

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear
2323
LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier,
2424
promote_u0, get_concrete_u0, get_concrete_p,
2525
has_kwargs, extract_alg, promote_u0, checkkwargs, SteadyStateProblem,
26-
NoDefaultAlgorithmError, NonSolverError, KeywordArgError
26+
NoDefaultAlgorithmError, NonSolverError, KeywordArgError, AbstractDEAlgorithm
2727
import SciMLBase: solve, init, __init, __solve, wrap_sol, get_root_indp, isinplace, remake
2828

2929
using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator

lib/NonlinearSolveBase/src/solve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -794,11 +794,11 @@ end
794794
function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true,
795795
kwargs...)
796796
alg = extract_alg(args, kwargs, prob.kwargs)
797-
if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling
798-
_prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,
797+
if isnothing(alg) || !(alg isa AbstractNonlinearAlgorithm) # Default algorithm handling
798+
_prob = get_concrete_problem(prob, true; u0 = u0,
799799
p = p, kwargs...)
800800
else
801-
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
801+
_prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...)
802802
end
803803

804804
if has_kwargs(_prob)

lib/NonlinearSolveBase/src/termination_conditions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ function CommonSolve.init(
8282
length(saved_value_prototype) == 0 && (saved_value_prototype = nothing)
8383

8484
leastsq = typeof(prob) <: NonlinearLeastSquaresProblem
85-
85+
Main.@infiltrate
8686
return NonlinearTerminationModeCache(
8787
u_unaliased, ReturnCode.Default, abstol, reltol, best_value, mode,
8888
initial_objective, objectives_trace, 0, saved_value_prototype,

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,18 @@ module SimpleNonlinearSolveChainRulesCoreExt
22

33
using ChainRulesCore: ChainRulesCore, NoTangent
44

5-
using NonlinearSolveBase: ImmutableNonlinearProblem
5+
using NonlinearSolveBase: ImmutableNonlinearProblem, _solve_adjoint
66
using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem
77

8-
using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up,
9-
solve_adjoint
8+
using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up
109

1110
function ChainRulesCore.rrule(
1211
::typeof(simplenonlinearsolve_solve_up),
1312
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
1413
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...
1514
)
1615
out,
17-
∇internal = solve_adjoint(
16+
∇internal = _solve_adjoint(
1817
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...
1918
)
2019
function ∇simplenonlinearsolve_solve_up(Δ)

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ module SimpleNonlinearSolveDiffEqBaseExt
22

33
#using DiffEqBase: DiffEqBase
44

5-
using SimpleNonlinearSolve: SimpleNonlinearSolve
5+
# using SimpleNonlinearSolve: SimpleNonlinearSolve
66

7-
SimpleNonlinearSolve.is_extension_loaded(::Val{:DiffEqBase}) = true
7+
# SimpleNonlinearSolve.is_extension_loaded(::Val{:DiffEqBase}) = true
88

9-
function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...)
10-
return DiffEqBase._solve_adjoint(args...; kwargs...)
11-
end
9+
# function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...)
10+
# return DiffEqBase._solve_adjoint(args...; kwargs...)
11+
# end
1212

1313
end

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
module SimpleNonlinearSolveReverseDiffExt
22

3-
using NonlinearSolveBase: ImmutableNonlinearProblem
3+
using NonlinearSolveBase: ImmutableNonlinearProblem, _solve_adjoint
44
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake
55

66
using ArrayInterface: ArrayInterface
77
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
88

9-
using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint
9+
using SimpleNonlinearSolve: SimpleNonlinearSolve
1010
import SimpleNonlinearSolve: simplenonlinearsolve_solve_up
1111

1212
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
@@ -27,7 +27,7 @@ for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
2727
u0, p = ReverseDiff.value(tu0), ReverseDiff.value(tp)
2828
prob = remake(tprob; u0, p)
2929
out,
30-
∇internal = solve_adjoint(
30+
∇internal = _solve_adjoint(
3131
prob, sensealg, u0, p, ReverseDiffOriginator(), alg, args...; kwargs...)
3232

3333
function ∇simplenonlinearsolve_solve_up...)

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
module SimpleNonlinearSolveTrackerExt
22

3-
using NonlinearSolveBase: ImmutableNonlinearProblem
3+
using NonlinearSolveBase: ImmutableNonlinearProblem, _solve_adjoint
44
using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake
55

66
using ArrayInterface: ArrayInterface
77
using Tracker: Tracker, TrackedArray, TrackedReal
88

9-
using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint
9+
using SimpleNonlinearSolve: SimpleNonlinearSolve
1010

1111
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
1212
aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any)
@@ -26,7 +26,7 @@ for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
2626
u0, p = Tracker.data(tu0), Tracker.data(tp)
2727
prob = remake(tprob; u0, p)
2828
out,
29-
∇internal = solve_adjoint(
29+
∇internal = _solve_adjoint(
3030
prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...)
3131

3232
function ∇simplenonlinearsolve_solve_up(Δ)

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,12 @@ end
125125

126126
# NOTE: This is defined like this so that we don't have to keep have 2 args for the
127127
# extensions
128-
function solve_adjoint(args...; kws...)
129-
is_extension_loaded(Val(:DiffEqBase)) && return solve_adjoint_internal(args...; kws...)
130-
error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.")
131-
end
128+
# function solve_adjoint(args...; kws...)
129+
# is_extension_loaded(Val(:DiffEqBase)) && return solve_adjoint_internal(args...; kws...)
130+
# error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.")
131+
# end
132132

133-
function solve_adjoint_internal end
133+
# function solve_adjoint_internal end
134134

135135
@setup_workload begin
136136
for T in (Float64,)

lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testitem "Simple Adjoint Test" tags=[:adjoint] begin
2-
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, DiffEqBase
2+
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote
33

44
ff(u, p) = u .^ 2 .- p
55

test/adjoint_tests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, NonlinearSolve, Test
2+
3+
ff(u, p) = u .^ 2 .- p
4+
5+
function solve_nlprob(p)
6+
prob = NonlinearProblem{false}(ff, [1.0, 2.0], p)
7+
sol = solve(prob, NewtonRaphson())
8+
res = sol isa AbstractArray ? sol : sol.u
9+
return sum(abs2, res)
10+
end
11+
12+
p = [3.0, 2.0]
13+
14+
∂p_zygote = only(Zygote.gradient(solve_nlprob, p))
15+
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
16+
∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p)))
17+
∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p)
18+
@test ∂p_zygote ∂p_tracker ∂p_reversediff
19+
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff

0 commit comments

Comments
 (0)