Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
345 changes: 303 additions & 42 deletions lib/OptimizationODE/src/OptimizationODE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,189 @@ module OptimizationODE

using Reexport
@reexport using Optimization, SciMLBase
using OrdinaryDiffEq, SteadyStateDiffEq
using LinearAlgebra, ForwardDiff

using NonlinearSolve
using OrdinaryDiffEq, DifferentialEquations, SteadyStateDiffEq, Sundials

export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent
export DAEOptimizer, DAEMassMatrix, DAEIndexing

struct ODEOptimizer{T}
solver::T
end

ODEGradientDescent() = ODEOptimizer(Euler())
RKChebyshevDescent() = ODEOptimizer(ROCK2())
RKAccelerated() = ODEOptimizer(Tsit5())
HighOrderDescent() = ODEOptimizer(Vern7())

struct ODEOptimizer{T, T2}
struct DAEOptimizer{T}
solver::T
dt::T2
end
ODEOptimizer(solver ; dt=nothing) = ODEOptimizer(solver, dt)

# Solver Constructors (users call these)
ODEGradientDescent(; dt) = ODEOptimizer(Euler(); dt)
RKChebyshevDescent() = ODEOptimizer(ROCK2())
RKAccelerated() = ODEOptimizer(Tsit5())
HighOrderDescent() = ODEOptimizer(Vern7())
DAEMassMatrix() = DAEOptimizer(Rodas5())
DAEIndexing() = DAEOptimizer(IDA())


SciMLBase.requiresbounds(::ODEOptimizer) = false
SciMLBase.allowsbounds(::ODEOptimizer) = false
SciMLBase.allowscallback(::ODEOptimizer) = true
SciMLBase.requiresbounds(::ODEOptimizer) = false
SciMLBase.allowsbounds(::ODEOptimizer) = false
SciMLBase.allowscallback(::ODEOptimizer) = true
SciMLBase.supports_opt_cache_interface(::ODEOptimizer) = true
SciMLBase.requiresgradient(::ODEOptimizer) = true
SciMLBase.requireshessian(::ODEOptimizer) = false
SciMLBase.requiresconsjac(::ODEOptimizer) = false
SciMLBase.requiresconshess(::ODEOptimizer) = false
SciMLBase.requiresgradient(::ODEOptimizer) = true
SciMLBase.requireshessian(::ODEOptimizer) = false
SciMLBase.requiresconsjac(::ODEOptimizer) = false
SciMLBase.requiresconshess(::ODEOptimizer) = false


SciMLBase.requiresbounds(::DAEOptimizer) = false
SciMLBase.allowsbounds(::DAEOptimizer) = false
SciMLBase.allowsconstraints(::DAEOptimizer) = true
SciMLBase.allowscallback(::DAEOptimizer) = true
SciMLBase.supports_opt_cache_interface(::DAEOptimizer) = true
SciMLBase.requiresgradient(::DAEOptimizer) = true
SciMLBase.requireshessian(::DAEOptimizer) = false
SciMLBase.requiresconsjac(::DAEOptimizer) = true
SciMLBase.requiresconshess(::DAEOptimizer) = false


function SciMLBase.__init(prob::OptimizationProblem, opt::ODEOptimizer;
callback=Optimization.DEFAULT_CALLBACK, progress=false,
callback=Optimization.DEFAULT_CALLBACK, progress=false, dt=nothing,
maxiters=nothing, kwargs...)

return OptimizationCache(prob, opt; callback=callback, progress=progress,
return OptimizationCache(prob, opt; callback=callback, progress=progress, dt=dt,
maxiters=maxiters, kwargs...)
end

function SciMLBase.__solve(
cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
) where {F,RC,LB,UB,LC,UC,S,O<:ODEOptimizer,D,P,C}
function SciMLBase.__init(prob::OptimizationProblem, opt::DAEOptimizer;
callback=Optimization.DEFAULT_CALLBACK, progress=false, dt=nothing,
maxiters=nothing, differential_vars=nothing, kwargs...)
return OptimizationCache(prob, opt; callback=callback, progress=progress, dt=dt,
maxiters=maxiters, differential_vars=differential_vars, kwargs...)
end


function solve_constrained_root(cache, u0, p)
n = length(u0)
cons_vals = cache.f.cons(u0, p)
m = length(cons_vals)
function resid!(res, u)
temp = similar(u)
f_mass!(temp, u, p, 0.0)
res .= temp
end
u0_ext = vcat(u0, zeros(m))
prob_nl = NonlinearProblem(resid!, u0_ext, p)
sol_nl = solve(prob_nl, Newton(); tol = 1e-8, maxiters = 100000,
callback = cache.callback, progress = get(cache.solver_args, :progress, false))
u_ext = sol_nl.u
return u_ext[1:n], sol_nl.retcode
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this about?



function get_solver_type(opt::DAEOptimizer)
if opt.solver isa Union{Rodas5, RadauIIA5, ImplicitEuler, Trapezoid}
return :mass_matrix
else
return :indexing
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use opt.solver isa DAEAlgorithm instead

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been removed in the latest commits.

end

dt = cache.opt.dt
maxit = get(cache.solver_args, :maxiters, 1000)
function handle_parameters(p)
if p isa SciMLBase.NullParameters
return Float64[]
else
return p
end
end

function setup_progress_callback(cache, solve_kwargs)
if get(cache.solver_args, :progress, false)
condition = (u, t, integrator) -> true
affect! = (integrator) -> begin
u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u
cache.solver_args[:callback](u_opt, integrator.p, integrator.t)
end
cb = DiscreteCallback(condition, affect!)
solve_kwargs[:callback] = cb
end
return solve_kwargs
end

function finite_difference_jacobian(f, x; ϵ = 1e-8)
n = length(x)
fx = f(x)
if fx === nothing
return zeros(eltype(x), 0, n)
elseif isa(fx, Number)
J = zeros(eltype(fx), 1, n)
for j in 1:n
xj = copy(x)
xj[j] += ϵ
diff = f(xj)
if diff === nothing
diffval = zero(eltype(fx))
else
diffval = diff - fx
end
J[1, j] = diffval / ϵ
end
return J
else
m = length(fx)
J = zeros(eltype(fx), m, n)
for j in 1:n
xj = copy(x)
xj[j] += ϵ
fxj = f(xj)
if fxj === nothing
@inbounds for i in 1:m
J[i, j] = -fx[i] / ϵ
end
else
@inbounds for i in 1:m
J[i, j] = (fxj[i] - fx[i]) / ϵ
end
end
end
return J
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary


function SciMLBase.__solve(
cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
) where {F,RC,LB,UB,LC,UC,S,O<:Union{ODEOptimizer,DAEOptimizer},D,P,C}

dt = get(cache.solver_args, :dt, nothing)
maxit = get(cache.solver_args, :maxiters, nothing)
differential_vars = get(cache.solver_args, :differential_vars, nothing)
u0 = copy(cache.u0)
p = cache.p
p = handle_parameters(cache.p) # Properly handle NullParameters

if cache.opt isa ODEOptimizer
return solve_ode(cache, dt, maxit, u0, p)
else
solver_method = get_solver_type(cache.opt)
if solver_method == :mass_matrix
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The actual check is alg isa DAEAlgorithm which means implicit, otherwise mass matrix

return solve_dae_mass_matrix(cache, dt, maxit, u0, p)
else
return solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume you mean implicit?

end
end
end

function solve_ode(cache, dt, maxit, u0, p)
if cache.f.grad === nothing
error("ODEOptimizer requires a gradient. Please provide a function with `grad` defined.")
end

function f!(du, u, p, t)
cache.f.grad(du, u, p)
@. du = -du
grad_vec = similar(u)
if isempty(p)
cache.f.grad(grad_vec, u)
else
cache.f.grad(grad_vec, u, p)
end
@. du = -grad_vec
return nothing
end

Expand All @@ -62,14 +193,11 @@ function SciMLBase.__solve(
algorithm = DynamicSS(cache.opt.solver)

cb = cache.callback
if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args,:progress,false) === true
function condition(u, t, integrator)
true
end
if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args,:progress,false)
function condition(u, t, integrator) true end
function affect!(integrator)
u_now = integrator.u
state = Optimization.OptimizationState(u=u_now, objective=cache.f(integrator.u, integrator.p))
Optimization.callback_function(cb, state)
cache.callback(u_now, integrator.p, integrator.t)
end
cb_struct = DiscreteCallback(condition, affect!)
callback = CallbackSet(cb_struct)
Expand All @@ -86,21 +214,154 @@ function SciMLBase.__solve(
end

sol = solve(ss_prob, algorithm; solve_kwargs...)
has_destats = hasproperty(sol, :destats)
has_t = hasproperty(sol, :t) && !isempty(sol.t)
has_destats = hasproperty(sol, :destats)
has_t = hasproperty(sol, :t) && !isempty(sol.t)

stats = Optimization.OptimizationStats(
iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10),
time = has_t ? sol.t[end] : 0.0,
fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0,
gevals = has_destats ? get(sol.destats, :iters, 0) : 0,
hevals = 0
)
stats = Optimization.OptimizationStats(
iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10),
time = has_t ? sol.t[end] : 0.0,
fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0,
gevals = has_destats ? get(sol.destats, :iters, 0) : 0,
hevals = 0
)

SciMLBase.build_solution(cache, cache.opt, sol.u, cache.f(sol.u, p);
retcode = ReturnCode.Success,
stats = stats
)
end

function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
if cache.f.cons === nothing
return solve_ode(cache, dt, maxit, u0, p)
end
x=u0
cons_vals = cache.f.cons(x, p)
n = length(u0)
m = length(cons_vals)
u0_extended = vcat(u0, zeros(m))
M = zeros(n + m, n + m)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just make M a diagonal matrix

M[1:n, 1:n] = I(n)

function f_mass!(du, u, p_, t)
x = @view u[1:n]
λ = @view u[n+1:end]
grad_f = similar(x)
if cache.f.grad !== nothing
cache.f.grad(grad_f, x, p_)
else
grad_f .= ForwardDiff.gradient(z -> cache.f.f(z, p_), x)
end
J = Matrix{eltype(x)}(undef, m, n)
if cache.f.cons_j !== nothing
cache.f.cons_j(J, x)
else
J .= finite_difference_jacobian(z -> cache.f.cons(z, p_), x)
end
@. du[1:n] = -grad_f - (J' * λ)
consv = cache.f.cons(x, p_)
if consv === nothing
fill!(du[n+1:end], zero(eltype(x)))
else
if isa(consv, Number)
@assert m == 1
du[n+1] = consv
else
@assert length(consv) == m
@. du[n+1:end] = consv
end
end
return nothing
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary, it's just 0


if m == 0
optf = ODEFunction(f_mass!, mass_matrix = I(n))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
optf = ODEFunction(f_mass!, mass_matrix = I(n))
optf = ODEFunction(f_mass!)

prob = ODEProblem(optf, u0, (0.0, 1.0), p)
return solve(prob, HighOrderDescent(); dt=dt, maxiters=maxit)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is the solver being specified?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when falling back to ode, I thought it better to use the best method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not necessarily the best method.

end

ss_prob = SteadyStateProblem(ODEFunction(f_mass!, mass_matrix = M), u0_extended, p)

solve_kwargs = setup_progress_callback(cache, Dict())
if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
if dt !== nothing; solve_kwargs[:dt] = dt; end

sol = solve(ss_prob, DynamicSS(cache.opt.solver); solve_kwargs...)
# if sol.retcode ≠ ReturnCode.Success
# # you may still accept Default or warn
# end
u_ext = sol.u
u_final = u_ext[1:n]
return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p);
retcode = sol.retcode)
end


function solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
if cache.f.cons === nothing
return solve_ode(cache, dt, maxit, u0, p)
end
x=u0
cons_vals = cache.f.cons(x, p)
n = length(u0)
m = length(cons_vals)
u0_ext = vcat(u0, zeros(m))
du0_ext = zeros(n + m)

if differential_vars === nothing
differential_vars = vcat(fill(true, n), fill(false, m))
else
if length(differential_vars) == n
differential_vars = vcat(differential_vars, fill(false, m))
elseif length(differential_vars) == n + m
# use as is
else
error("differential_vars length must be number of variables ($n) or extended size ($(n+m))")
end
end

function dae_residual!(res, du, u, p_, t)
x = @view u[1:n]
λ = @view u[n+1:end]
du_x = @view du[1:n]
grad_f = similar(x)
cache.f.grad(grad_f, x, p_)
J = zeros(m, n)
if cache.f.cons_j !== nothing
cache.f.cons_j(J, x)
else
J .= finite_difference_jacobian(z -> cache.f.cons(z,p_), x)
end
@. res[1:n] = du_x + grad_f + J' * λ
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just zero = cons for the other terms.

consv = cache.f.cons(x, p_)
@. res[n+1:end] = consv
return nothing
end

if m == 0
optf = ODEFunction(dae_residual!, differential_vars = differential_vars)
prob = ODEProblem(optf, du0_ext, (0.0, 1.0), p)
return solve(prob, HighOrderDescent(); dt=dt, maxiters=maxit)
end

tspan = (0.0, 10.0)
prob = DAEProblem(dae_residual!, du0_ext, u0_ext, tspan, p;
differential_vars = differential_vars)

solve_kwargs = setup_progress_callback(cache, Dict())
if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
if dt !== nothing; solve_kwargs[:dt] = dt; end
if hasfield(typeof(cache.opt.solver), :initializealg)
solve_kwargs[:initializealg] = BrownFullBasicInit()
end

sol = solve(prob, cache.opt.solver; solve_kwargs...)
u_ext = sol.u
u_final = u_ext[end][1:n]

return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p);
retcode = sol.retcode)
end


end
Loading
Loading