DAE optimizers added to OptimizationODE #932
@@ -2,58 +2,189 @@ module OptimizationODE

using Reexport
@reexport using Optimization, SciMLBase
using LinearAlgebra, ForwardDiff
using NonlinearSolve
using OrdinaryDiffEq, DifferentialEquations, SteadyStateDiffEq, Sundials

export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent
export DAEOptimizer, DAEMassMatrix, DAEIndexing
struct ODEOptimizer{T, T2}
    solver::T
    dt::T2
end
ODEOptimizer(solver; dt = nothing) = ODEOptimizer(solver, dt)

struct DAEOptimizer{T}
    solver::T
end

# Solver constructors (users call these)
ODEGradientDescent(; dt) = ODEOptimizer(Euler(); dt)
RKChebyshevDescent() = ODEOptimizer(ROCK2())
RKAccelerated() = ODEOptimizer(Tsit5())
HighOrderDescent() = ODEOptimizer(Vern7())
DAEMassMatrix() = DAEOptimizer(Rodas5())
DAEIndexing() = DAEOptimizer(IDA())
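For orientation, a minimal usage sketch of the new constructors (hypothetical Rosenbrock objective, not part of this diff):

```julia
using Optimization, OptimizationODE

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, [0.0, 0.0], [1.0, 100.0])

sol = solve(prob, ODEGradientDescent(dt = 0.01); maxiters = 10_000)  # fixed-step Euler needs dt
sol = solve(prob, RKAccelerated())                                   # adaptive Tsit5 variant
```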
SciMLBase.requiresbounds(::ODEOptimizer) = false
SciMLBase.allowsbounds(::ODEOptimizer) = false
SciMLBase.allowscallback(::ODEOptimizer) = true
SciMLBase.supports_opt_cache_interface(::ODEOptimizer) = true
SciMLBase.requiresgradient(::ODEOptimizer) = true
SciMLBase.requireshessian(::ODEOptimizer) = false
SciMLBase.requiresconsjac(::ODEOptimizer) = false
SciMLBase.requiresconshess(::ODEOptimizer) = false

SciMLBase.requiresbounds(::DAEOptimizer) = false
SciMLBase.allowsbounds(::DAEOptimizer) = false
SciMLBase.allowsconstraints(::DAEOptimizer) = true
SciMLBase.allowscallback(::DAEOptimizer) = true
SciMLBase.supports_opt_cache_interface(::DAEOptimizer) = true
SciMLBase.requiresgradient(::DAEOptimizer) = true
SciMLBase.requireshessian(::DAEOptimizer) = false
SciMLBase.requiresconsjac(::DAEOptimizer) = true
SciMLBase.requiresconshess(::DAEOptimizer) = false
function SciMLBase.__init(prob::OptimizationProblem, opt::ODEOptimizer;
        callback = Optimization.DEFAULT_CALLBACK, progress = false, dt = nothing,
        maxiters = nothing, kwargs...)
    return OptimizationCache(prob, opt; callback = callback, progress = progress, dt = dt,
        maxiters = maxiters, kwargs...)
end

function SciMLBase.__init(prob::OptimizationProblem, opt::DAEOptimizer;
        callback = Optimization.DEFAULT_CALLBACK, progress = false, dt = nothing,
        maxiters = nothing, differential_vars = nothing, kwargs...)
    return OptimizationCache(prob, opt; callback = callback, progress = progress, dt = dt,
        maxiters = maxiters, differential_vars = differential_vars, kwargs...)
end
function solve_constrained_root(cache, u0, p)
    n = length(u0)
    cons_vals = cache.f.cons(u0, p)
    m = length(cons_vals)
    # NOTE: relies on an `f_mass!` residual being available in the enclosing scope;
    # the in-place NonlinearProblem signature is resid!(res, u, p).
    function resid!(res, u, p)
        temp = similar(u)
        f_mass!(temp, u, p, 0.0)
        res .= temp
    end
    u0_ext = vcat(u0, zeros(m))
    prob_nl = NonlinearProblem(resid!, u0_ext, p)
    # `Newton()` is not a NonlinearSolve solver name; NewtonRaphson with abstol is the
    # corresponding call
    sol_nl = solve(prob_nl, NewtonRaphson(); abstol = 1e-8, maxiters = 100000)
    u_ext = sol_nl.u
    return u_ext[1:n], sol_nl.retcode
end
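The same rootfinding pattern in isolation, as a self-contained sketch (hypothetical residual, not this PR's `f_mass!`):

```julia
using NonlinearSolve

residual!(res, u, p) = (res .= u .^ 2 .- p)   # solve u^2 = p componentwise
prob_nl = NonlinearProblem(residual!, [1.0, 1.0], 2.0)
sol_nl = solve(prob_nl, NewtonRaphson(); abstol = 1e-8, maxiters = 100_000)
sol_nl.u  # ≈ [1.414, 1.414]
```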
function get_solver_type(opt::DAEOptimizer)
    if opt.solver isa Union{Rodas5, RadauIIA5, ImplicitEuler, Trapezoid}
        return :mass_matrix
    else
        return :indexing
    end
end

Review comment: Use
Reply: This has been removed in the latest commits.
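Concretely, the two stock DAE constructors route as follows (a sketch based on the dispatch above):

```julia
get_solver_type(DAEMassMatrix())  # :mass_matrix, since Rodas5 handles mass-matrix ODEs
get_solver_type(DAEIndexing())    # :indexing, since IDA solves fully implicit DAEs
```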
function handle_parameters(p)
    if p isa SciMLBase.NullParameters
        return Float64[]
    else
        return p
    end
end

function setup_progress_callback(cache, solve_kwargs)
    if get(cache.solver_args, :progress, false)
        condition = (u, t, integrator) -> true
        affect! = (integrator) -> begin
            u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u
            cache.solver_args[:callback](u_opt, integrator.p, integrator.t)
        end
        cb = DiscreteCallback(condition, affect!)
        solve_kwargs[:callback] = cb
    end
    return solve_kwargs
end
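Behavior of the parameter shim, for reference:

```julia
handle_parameters(SciMLBase.NullParameters())  # -> Float64[], safe for isempty checks
handle_parameters([1.0, 2.0])                  # -> [1.0, 2.0], passed through unchanged
```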
function finite_difference_jacobian(f, x; ϵ = 1e-8)
    n = length(x)
    fx = f(x)
    if fx === nothing
        return zeros(eltype(x), 0, n)
    elseif isa(fx, Number)
        J = zeros(eltype(fx), 1, n)
        for j in 1:n
            xj = copy(x)
            xj[j] += ϵ
            diff = f(xj)
            diffval = diff === nothing ? zero(eltype(fx)) : diff - fx
            J[1, j] = diffval / ϵ
        end
        return J
    else
        m = length(fx)
        J = zeros(eltype(fx), m, n)
        for j in 1:n
            xj = copy(x)
            xj[j] += ϵ
            fxj = f(xj)
            if fxj === nothing
                @inbounds for i in 1:m
                    J[i, j] = -fx[i] / ϵ
                end
            else
                @inbounds for i in 1:m
                    J[i, j] = (fxj[i] - fx[i]) / ϵ
                end
            end
        end
        return J
    end
end

Review comment: Unnecessary
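A quick sanity check of this helper against ForwardDiff (hypothetical test function, not part of the diff):

```julia
using ForwardDiff

g(x) = [x[1]^2 + x[2], sin(x[2])]
x0 = [1.0, 2.0]
J_fd = finite_difference_jacobian(g, x0)
J_ad = ForwardDiff.jacobian(g, x0)
maximum(abs.(J_fd .- J_ad)) < 1e-6  # true: forward differences agree to O(ϵ)
```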
function SciMLBase.__solve(
        cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
        ) where {F,RC,LB,UB,LC,UC,S,O<:Union{ODEOptimizer,DAEOptimizer},D,P,C}

    dt = get(cache.solver_args, :dt, nothing)
    maxit = get(cache.solver_args, :maxiters, nothing)
    differential_vars = get(cache.solver_args, :differential_vars, nothing)
    u0 = copy(cache.u0)
    p = handle_parameters(cache.p)  # properly handle NullParameters

    if cache.opt isa ODEOptimizer
        return solve_ode(cache, dt, maxit, u0, p)
    else
        solver_method = get_solver_type(cache.opt)
        if solver_method == :mass_matrix
            return solve_dae_mass_matrix(cache, dt, maxit, u0, p)
        else
            return solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
        end
    end
end

Review comment (on the :mass_matrix branch): The actual check is
Review comment (on the :indexing branch): I assume you mean implicit?
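End to end, a constrained problem flows through the DAE path roughly like this. A sketch; note it assumes the out-of-place constraint signature `cons(x, p)` that this PR calls, whereas Optimization.jl's documented constraint interface is in-place:

```julia
obj(x, p) = x[1]^2 + x[2]^2
cons(x, p) = [x[1] + x[2] - 1.0]  # single equality constraint
optf = OptimizationFunction(obj, Optimization.AutoForwardDiff(); cons = cons)
prob = OptimizationProblem(optf, [0.5, 0.5])
sol = solve(prob, DAEMassMatrix())  # dispatches to solve_dae_mass_matrix
```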
function solve_ode(cache, dt, maxit, u0, p)
    if cache.f.grad === nothing
        error("ODEOptimizer requires a gradient. Please provide a function with `grad` defined.")
    end

    function f!(du, u, p, t)
        grad_vec = similar(u)
        if isempty(p)
            cache.f.grad(grad_vec, u)
        else
            cache.f.grad(grad_vec, u, p)
        end
        @. du = -grad_vec
        return nothing
    end
@@ -62,14 +193,11 @@ function SciMLBase.__solve(

    algorithm = DynamicSS(cache.opt.solver)

    cb = cache.callback
    if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args, :progress, false)
        function condition(u, t, integrator) true end
        function affect!(integrator)
            u_now = integrator.u
            cache.callback(u_now, integrator.p, integrator.t)
        end
        cb_struct = DiscreteCallback(condition, affect!)
        callback = CallbackSet(cb_struct)
@@ -86,21 +214,154 @@ function SciMLBase.__solve(
    end

    sol = solve(ss_prob, algorithm; solve_kwargs...)
    has_destats = hasproperty(sol, :destats)
    has_t = hasproperty(sol, :t) && !isempty(sol.t)

    stats = Optimization.OptimizationStats(
        iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10),
        time = has_t ? sol.t[end] : 0.0,
        fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0,
        gevals = has_destats ? get(sol.destats, :iters, 0) : 0,
        hevals = 0
    )

    SciMLBase.build_solution(cache, cache.opt, sol.u, cache.f(sol.u, p);
        retcode = ReturnCode.Success,
        stats = stats
    )
end
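For reference, this path is plain gradient flow, du/dt = -∇f(u), run to a steady state, which is a stationary point of f. The same idea in a self-contained sketch:

```julia
using OrdinaryDiffEq, SteadyStateDiffEq, ForwardDiff

f(u) = (u[1] - 3)^2 + (u[2] + 1)^2                       # minimizer at (3, -1)
rhs!(du, u, p, t) = (du .= -ForwardDiff.gradient(f, u))  # descent direction
ss = solve(SteadyStateProblem(rhs!, zeros(2)), DynamicSS(Tsit5()))
ss.u  # ≈ [3.0, -1.0]
```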
function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
    if cache.f.cons === nothing
        return solve_ode(cache, dt, maxit, u0, p)
    end
    x = u0
    cons_vals = cache.f.cons(x, p)
    n = length(u0)
    m = length(cons_vals)
    u0_extended = vcat(u0, zeros(m))
    M = zeros(n + m, n + m)
    M[1:n, 1:n] = I(n)

Review comment: Just make M a diagonal matrix
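Per that review comment, the dense allocation could be avoided with a `Diagonal` mass matrix, which OrdinaryDiffEq accepts directly; a sketch:

```julia
using LinearAlgebra

# ones on the differential (state) block, zeros on the algebraic (multiplier) block
M = Diagonal(vcat(ones(n), zeros(m)))
```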
    function f_mass!(du, u, p_, t)
        x = @view u[1:n]
        λ = @view u[n+1:end]
        grad_f = similar(x)
        if cache.f.grad !== nothing
            cache.f.grad(grad_f, x, p_)
        else
            grad_f .= ForwardDiff.gradient(z -> cache.f.f(z, p_), x)
        end
        J = Matrix{eltype(x)}(undef, m, n)
        if cache.f.cons_j !== nothing
            cache.f.cons_j(J, x)
        else
            J .= finite_difference_jacobian(z -> cache.f.cons(z, p_), x)
        end
        # J' * λ must stay a matrix-vector product, so it cannot sit under `@.`
        du[1:n] .= -grad_f .- (J' * λ)
        consv = cache.f.cons(x, p_)
        if consv === nothing
            fill!(@view(du[n+1:end]), zero(eltype(x)))  # view, so the fill mutates du
        elseif isa(consv, Number)
            @assert m == 1
            du[n+1] = consv
        else
            @assert length(consv) == m
            @. du[n+1:end] = consv
        end
        return nothing
    end

Review comment: Unnecessary, it's just 0
    if m == 0
        optf = ODEFunction(f_mass!, mass_matrix = I(n))
        prob = ODEProblem(optf, u0, (0.0, 1.0), p)
        return solve(prob, HighOrderDescent(); dt = dt, maxiters = maxit)
    end

Review comment: why is the solver being specified?
Reply: when falling back to ode, I thought it better to use the best method.
Reply: That's not necessarily the best method.
    ss_prob = SteadyStateProblem(ODEFunction(f_mass!, mass_matrix = M), u0_extended, p)

    solve_kwargs = setup_progress_callback(cache, Dict())
    if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
    if dt !== nothing; solve_kwargs[:dt] = dt; end

    sol = solve(ss_prob, DynamicSS(cache.opt.solver); solve_kwargs...)
    # if sol.retcode ≠ ReturnCode.Success
    #     # you may still accept Default or warn
    # end
    u_ext = sol.u
    u_final = u_ext[1:n]
    return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p);
        retcode = sol.retcode)
end
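What the steady state encodes: du/dt = 0 recovers the KKT conditions ∇f(x) + Jᵀλ = 0 and cons(x) = 0. A worked check on a hypothetical problem (min x₁² + x₂² subject to x₁ + x₂ = 1):

```julia
using LinearAlgebra

x_star = [0.5, 0.5]; λ_star = [-1.0]
grad_f = 2 .* x_star                 # ∇f(x) = 2x
J = [1.0 1.0]                        # constraint Jacobian
norm(grad_f .+ J' * λ_star)          # ≈ 0: first n rows of the system vanish
x_star[1] + x_star[2] - 1.0          # ≈ 0: algebraic rows vanish
```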
function solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
    if cache.f.cons === nothing
        return solve_ode(cache, dt, maxit, u0, p)
    end
    x = u0
    cons_vals = cache.f.cons(x, p)
    n = length(u0)
    m = length(cons_vals)
    u0_ext = vcat(u0, zeros(m))
    du0_ext = zeros(n + m)

    if differential_vars === nothing
        differential_vars = vcat(fill(true, n), fill(false, m))
    elseif length(differential_vars) == n
        differential_vars = vcat(differential_vars, fill(false, m))
    elseif length(differential_vars) == n + m
        # use as is
    else
        error("differential_vars length must be the number of variables ($n) or the extended size ($(n+m))")
    end
    function dae_residual!(res, du, u, p_, t)
        x = @view u[1:n]
        λ = @view u[n+1:end]
        du_x = @view du[1:n]
        grad_f = similar(x)
        cache.f.grad(grad_f, x, p_)
        J = zeros(m, n)
        if cache.f.cons_j !== nothing
            cache.f.cons_j(J, x)
        else
            J .= finite_difference_jacobian(z -> cache.f.cons(z, p_), x)
        end
        # matrix-vector product, kept outside `@.` so it is not broadcast elementwise
        res[1:n] .= du_x .+ grad_f .+ (J' * λ)

        consv = cache.f.cons(x, p_)
        res[n+1:end] .= consv
        return nothing
    end

Review comment: It's just zero = cons for the other terms.
    if m == 0
        optf = ODEFunction(dae_residual!, differential_vars = differential_vars)
        prob = ODEProblem(optf, du0_ext, (0.0, 1.0), p)
        return solve(prob, HighOrderDescent(); dt = dt, maxiters = maxit)
    end

    tspan = (0.0, 10.0)
    prob = DAEProblem(dae_residual!, du0_ext, u0_ext, tspan, p;
        differential_vars = differential_vars)

    solve_kwargs = setup_progress_callback(cache, Dict())
    if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
    if dt !== nothing; solve_kwargs[:dt] = dt; end
    if hasfield(typeof(cache.opt.solver), :initializealg)
        solve_kwargs[:initializealg] = BrownFullBasicInit()
    end

    sol = solve(prob, cache.opt.solver; solve_kwargs...)
    u_ext = sol.u
    u_final = u_ext[end][1:n]

    return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p);
        retcode = sol.retcode)
end
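A minimal end-to-end sketch of the index-style path (again assuming the out-of-place `cons(x, p)` signature this PR calls, with Sundials' IDA selected via `DAEIndexing`):

```julia
using Optimization, OptimizationODE

obj(x, p) = (x[1] - 1.0)^2 + (x[2] - 2.0)^2
cons(x, p) = [x[1] + x[2] - 2.0]
optf = OptimizationFunction(obj, Optimization.AutoForwardDiff(); cons = cons)
prob = OptimizationProblem(optf, [0.0, 0.0])

# mark the n state variables as differential; the m multipliers are algebraic
sol = solve(prob, DAEIndexing(); differential_vars = [true, true])
```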
end

Review comment: what's this about?