11const RelNormModes = Union{
2- RelNormTerminationMode, RelNormSafeTerminationMode, RelNormSafeBestTerminationMode}
2+ RelNormTerminationMode, RelNormSafeTerminationMode, RelNormSafeBestTerminationMode
3+ }
34const AbsNormModes = Union{
4- AbsNormTerminationMode, AbsNormSafeTerminationMode, AbsNormSafeBestTerminationMode}
5+ AbsNormTerminationMode, AbsNormSafeTerminationMode, AbsNormSafeBestTerminationMode
6+ }
57
68# Core Implementation
79@concrete mutable struct NonlinearTerminationModeCache{uType, T}
3234
3335function CommonSolve. init (
3436 :: AbstractNonlinearProblem , mode:: AbstractNonlinearTerminationMode , du, u,
35- saved_value_prototype... ; abstol = nothing , reltol = nothing , kwargs... )
37+ saved_value_prototype... ; abstol = nothing , reltol = nothing , kwargs...
38+ )
3639 T = promote_type (eltype (du), eltype (u))
3740 abstol = get_tolerance (u, abstol, T)
3841 reltol = get_tolerance (u, reltol, T)
@@ -77,12 +80,14 @@ function CommonSolve.init(
7780 return NonlinearTerminationModeCache (
7881 u_unaliased, ReturnCode. Default, abstol, reltol, best_value, mode,
7982 initial_objective, objectives_trace, 0 , saved_value_prototype,
80- u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache)
83+ u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache
84+ )
8185end
8286
8387function SciMLBase. reinit! (
8488 cache:: NonlinearTerminationModeCache , du, u, saved_value_prototype... ;
85- abstol = cache. abstol, reltol = cache. reltol, kwargs... )
89+ abstol = cache. abstol, reltol = cache. reltol, kwargs...
90+ )
8691 T = eltype (cache. abstol)
8792 length (saved_value_prototype) != 0 && (cache. saved_values = saved_value_prototype)
8893
113118
114119# # This dispatch is needed based on how Terminating Callback works!
115120function (cache:: NonlinearTerminationModeCache )(
116- integrator:: AbstractODEIntegrator , abstol:: Number , reltol:: Number , min_t)
121+ integrator:: AbstractODEIntegrator , abstol:: Number , reltol:: Number , min_t
122+ )
117123 if min_t === nothing || integrator. t ≥ min_t
118124 return cache (cache. mode, SciMLBase. get_du (integrator),
119125 integrator. u, integrator. uprev, abstol, reltol)
@@ -125,7 +131,8 @@ function (cache::NonlinearTerminationModeCache)(du, u, uprev, args...)
125131end
126132
127133function (cache:: NonlinearTerminationModeCache )(
128- mode:: AbstractNonlinearTerminationMode , du, u, uprev, abstol, reltol, args... )
134+ mode:: AbstractNonlinearTerminationMode , du, u, uprev, abstol, reltol, args...
135+ )
129136 if check_convergence (mode, du, u, uprev, abstol, reltol)
130137 cache. retcode = ReturnCode. Success
131138 return true
@@ -134,7 +141,8 @@ function (cache::NonlinearTerminationModeCache)(
134141end
135142
136143function (cache:: NonlinearTerminationModeCache )(
137- mode:: AbstractSafeNonlinearTerminationMode , du, u, uprev, abstol, reltol, args... )
144+ mode:: AbstractSafeNonlinearTerminationMode , du, u, uprev, abstol, reltol, args...
145+ )
138146 if mode isa AbsNormSafeTerminationMode || mode isa AbsNormSafeBestTerminationMode
139147 objective = Utils. apply_norm (mode. internalnorm, du)
140148 criteria = abstol
@@ -251,15 +259,17 @@ end
251259# High-Level API with defaults.
252260# # This is mostly for internal usage in NonlinearSolve and SimpleNonlinearSolve
253261function default_termination_mode (
254- :: Union{ImmutableNonlinearProblem, NonlinearProblem} , :: Val{:simple} )
262+ :: Union{ImmutableNonlinearProblem, NonlinearProblem} , :: Val{:simple}
263+ )
255264 return AbsNormTerminationMode (Base. Fix1 (maximum, abs))
256265end
257266function default_termination_mode (:: NonlinearLeastSquaresProblem , :: Val{:simple} )
258267 return AbsNormTerminationMode (Base. Fix2 (norm, 2 ))
259268end
260269
261270function default_termination_mode (
262- :: Union{ImmutableNonlinearProblem, NonlinearProblem} , :: Val{:regular} )
271+ :: Union{ImmutableNonlinearProblem, NonlinearProblem} , :: Val{:regular}
272+ )
263273 return AbsNormSafeBestTerminationMode (Base. Fix1 (maximum, abs); max_stalled_steps = 32 )
264274end
265275
@@ -268,16 +278,53 @@ function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:regular
268278end
269279
270280function init_termination_cache (
271- prob:: AbstractNonlinearProblem , abstol, reltol, du, u, :: Nothing , callee:: Val )
281+ prob:: AbstractNonlinearProblem , abstol, reltol, du, u, :: Nothing , callee:: Val
282+ )
272283 return init_termination_cache (
273284 prob, abstol, reltol, du, u, default_termination_mode (prob, callee), callee)
274285end
275286
276287function init_termination_cache (prob:: AbstractNonlinearProblem , abstol, reltol, du,
277- u, tc:: AbstractNonlinearTerminationMode , :: Val )
288+ u, tc:: AbstractNonlinearTerminationMode , :: Val
289+ )
278290 T = promote_type (eltype (du), eltype (u))
279291 abstol = get_tolerance (u, abstol, T)
280292 reltol = get_tolerance (u, reltol, T)
281293 cache = init (prob, tc, du, u; abstol, reltol)
282294 return abstol, reltol, cache
283295end
296+
297+ function check_and_update! (cache, fu, u, uprev)
298+ return check_and_update! (
299+ cache. termination_cache, cache, fu, u, uprev, cache. termination_cache. mode
300+ )
301+ end
302+
303+ function check_and_update! (tc_cache, cache, fu, u, uprev, mode)
304+ if tc_cache (fu, u, uprev)
305+ cache. retcode = tc_cache. retcode
306+ update_from_termination_cache! (tc_cache, cache, mode, u)
307+ cache. force_stop = true
308+ end
309+ end
310+
311+ function update_from_termination_cache! (tc_cache, cache, u = get_u (cache))
312+ return update_from_termination_cache! (tc_cache, cache, tc_cache. mode, u)
313+ end
314+
315+ function update_from_termination_cache! (
316+ tc_cache, cache, :: AbstractNonlinearTerminationMode , u = get_u (cache)
317+ )
318+ Utils. evaluate_f! (cache, u, cache. p)
319+ end
320+
321+ function update_from_termination_cache! (
322+ tc_cache, cache, :: AbstractSafeBestNonlinearTerminationMode , u = get_u (cache)
323+ )
324+ if SciMLBase. isinplace (cache)
325+ copyto! (get_u (cache), tc_cache. u)
326+ else
327+ SciMLBase. set_u! (cache, tc_cache. u)
328+ end
329+ Utils. evaluate_f! (cache, get_u (cache), cache. p)
330+ end
0 commit comments