@@ -2,22 +2,22 @@ module NonlinearSolveBaseForwardDiffExt
22
33using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
44using ArrayInterface: ArrayInterface
5- using CommonSolve: CommonSolve, solve
5+ using CommonSolve: CommonSolve, solve, solve!, init
66using ConcreteStructs: @concrete
77using DifferentiationInterface: DifferentiationInterface
88using FastClosures: @closure
99using ForwardDiff: ForwardDiff, Dual
1010using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
1111 NonlinearProblem, NonlinearLeastSquaresProblem, remake
1212
13- using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
14- AbstractNonlinearSolveAlgorithm, Utils, InternalAPI ,
15- AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm
13+ using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI,
14+ AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm ,
15+ NonlinearSolveForwardDiffCache
1616
1717const DI = DifferentiationInterface
1818
1919const GENERAL_SOLVER_TYPES = [
20- Nothing, AbstractNonlinearSolveAlgorithm, NonlinearSolvePolyAlgorithm
20+ Nothing, NonlinearSolvePolyAlgorithm
2121]
2222
2323const DualNonlinearProblem = NonlinearProblem{
@@ -135,24 +135,16 @@ for algType in GENERAL_SOLVER_TYPES
135135 end
136136end
137137
138- @concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
139- cache
140- prob
141- alg
142- p
143- values_p
144- partials_p
145- end
146-
147138function InternalAPI. reinit! (
148139 cache:: NonlinearSolveForwardDiffCache , args... ;
149140 p = cache. p, u0 = NonlinearSolveBase. get_u (cache. cache), kwargs...
150141)
151142 InternalAPI. reinit! (
152- cache. cache; p = nodual_value (p), u0 = nodual_value (u0), kwargs...
143+ cache. cache; p = NonlinearSolveBase. nodual_value (p),
144+ u0 = NonlinearSolveBase. nodual_value (u0), kwargs...
153145 )
154146 cache. p = p
155- cache. values_p = nodual_value (p)
147+ cache. values_p = NonlinearSolveBase . nodual_value (p)
156148 cache. partials_p = ForwardDiff. partials (p)
157149 return cache
158150end
@@ -161,8 +153,8 @@ for algType in GENERAL_SOLVER_TYPES
161153 @eval function SciMLBase. __init (
162154 prob:: DualAbstractNonlinearProblem , alg:: $ (algType), args... ; kwargs...
163155 )
164- p = nodual_value (prob. p)
165- newprob = SciMLBase. remake (prob; u0 = nodual_value (prob. u0), p)
156+ p = NonlinearSolveBase . nodual_value (prob. p)
157+ newprob = SciMLBase. remake (prob; u0 = NonlinearSolveBase . nodual_value (prob. u0), p)
166158 cache = init (newprob, alg, args... ; kwargs... )
167159 return NonlinearSolveForwardDiffCache (
168160 cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p)
@@ -196,8 +188,17 @@ function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
196188 )
197189end
198190
199- nodual_value (x) = x
200- nodual_value (x:: Dual ) = ForwardDiff. value (x)
201- nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
191+ NonlinearSolveBase. nodual_value (x) = x
192+ NonlinearSolveBase. nodual_value (x:: Dual ) = ForwardDiff. value (x)
193+ NonlinearSolveBase. nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
194+
195+ """
196+ pickchunksize(x) = pickchunksize(length(x))
197+ pickchunksize(x::Int)
198+
199+ Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length.
200+ """
201+ @inline NonlinearSolveBase. pickchunksize (x) = pickchunksize (length (x))
202+ @inline NonlinearSolveBase. pickchunksize (x:: Int ) = ForwardDiff. pickchunksize (x)
202203
203204end
0 commit comments