Skip to content

Commit 55d8df5

Browse files
committed
refactor fit!(::LinearMixedModel) to support a more flexible optimization backend
1 parent 1ea4083 commit 55d8df5

File tree

6 files changed

+184
-137
lines changed

6 files changed

+184
-137
lines changed

src/MixedModels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ include("grouping.jl")
205205
include("mimeshow.jl")
206206
include("serialization.jl")
207207
include("profile/profile.jl")
208+
include("nlopt.jl")
208209
include("prima.jl")
209210

210211
# COV_EXCL_START

src/linearmixedmodel.jl

Lines changed: 2 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ function LinearMixedModel(
177177
θ = foldl(vcat, getθ(c) for c in reterms)
178178
optsum = OptSummary(θ, lbd)
179179
optsum.sigma = isnothing(σ) ? nothing : T(σ)
180-
fill!(optsum.xtol_abs, 1.0e-10)
181180
return LinearMixedModel(
182181
form,
183182
reterms,
@@ -459,55 +458,10 @@ function StatsAPI.fit!(
459458
ArgumentError("The response is constant and thus model fitting has failed")
460459
)
461460
end
462-
opt = Opt(optsum)
463461
optsum.REML = REML
464462
optsum.sigma = σ
465-
prog = ProgressUnknown(; desc="Minimizing", showspeed=true)
466-
# start from zero for the initial call to obj before optimization
467-
iter = 0
463+
xmin, fmin = optimize!(m; progress, thin)
468464
fitlog = optsum.fitlog
469-
function obj(x, g)
470-
isempty(g) || throw(ArgumentError("g should be empty for this objective"))
471-
iter += 1
472-
val = if isone(iter) && x == optsum.initial
473-
optsum.finitial
474-
else
475-
try
476-
objective(updateL!(setθ!(m, x)))
477-
catch ex
478-
# This can happen when the optimizer drifts into an area where
479-
# there isn't enough shrinkage. Why finitial? Generally, it will
480-
# be the (near) worst case scenario value, so the optimizer won't
481-
# view it as an optimum. Using Inf messes up the quadratic
482-
# approximation in BOBYQA.
483-
ex isa PosDefException || rethrow()
484-
optsum.finitial
485-
end
486-
end
487-
progress && ProgressMeter.next!(prog; showvalues=[(:objective, val)])
488-
!isone(iter) && iszero(rem(iter, thin)) && push!(fitlog, (copy(x), val))
489-
return val
490-
end
491-
NLopt.min_objective!(opt, obj)
492-
try
493-
# use explicit evaluation w/o calling opt to avoid confusing iteration count
494-
optsum.finitial = objective(updateL!(setθ!(m, optsum.initial)))
495-
catch ex
496-
ex isa PosDefException || rethrow()
497-
# give it one more try with a massive change in scaling
498-
@info "Initial objective evaluation failed, rescaling initial guess and trying again."
499-
@warn """Failure of the initial evaluation is often indicative of a model specification
500-
that is not well supported by the data and/or a poorly scaled model.
501-
"""
502-
optsum.initial ./=
503-
(isempty(m.sqrtwts) ? 1.0 : maximum(m.sqrtwts)^2) *
504-
maximum(response(m))
505-
optsum.finitial = objective(updateL!(setθ!(m, optsum.initial)))
506-
end
507-
empty!(fitlog)
508-
push!(fitlog, (copy(optsum.initial), optsum.finitial))
509-
fmin, xmin, ret = NLopt.optimize!(opt, copyto!(optsum.final, optsum.initial))
510-
ProgressMeter.finish!(prog)
511465
## check if small non-negative parameter values can be set to zero
512466
xmin_ = copy(xmin)
513467
lb = optsum.lowerbd
@@ -518,7 +472,7 @@ function StatsAPI.fit!(
518472
end
519473
loglength = length(fitlog)
520474
if xmin ≠ xmin_
521-
if (zeroobj = obj(xmin_, T[])) ≤ (fmin + optsum.ftol_zero_abs)
475+
if (zeroobj = objective!(m, xmin_)) ≤ (fmin + optsum.ftol_zero_abs)
522476
fmin = zeroobj
523477
copyto!(xmin, xmin_)
524478
elseif length(fitlog) > loglength
@@ -529,11 +483,8 @@ function StatsAPI.fit!(
529483
## ensure that the parameter values saved in m are xmin
530484
updateL!(setθ!(m, xmin))
531485

532-
optsum.feval = opt.numevals
533486
optsum.final = xmin
534487
optsum.fmin = fmin
535-
optsum.returnvalue = ret
536-
_check_nlopt_return(ret)
537488
return m
538489
end
539490

@@ -828,25 +779,6 @@ function objective(m::LinearMixedModel{T}) where {T}
828779
return isempty(wts) ? val : val - T(2.0) * sum(log, wts)
829780
end
830781

831-
"""
832-
objective!(m::LinearMixedModel, θ)
833-
objective!(m::LinearMixedModel)
834-
835-
Equivalent to `objective(updateL!(setθ!(m, θ)))`.
836-
837-
When `m` has a single, scalar random-effects term, `θ` can be a scalar.
838-
839-
The one-argument method curries and returns a single-argument function of `θ`.
840-
841-
Note that these methods modify `m`.
842-
The calling function is responsible for restoring the optimal `θ`.
843-
"""
844-
function objective! end
845-
846-
function objective!(m::LinearMixedModel)
847-
return Base.Fix1(objective!, m)
848-
end
849-
850782
function objective!(m::LinearMixedModel{T}, θ) where {T}
851783
return objective(updateL!(setθ!(m, θ)))
852784
end

src/mixedmodel.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,25 @@ StatsAPI.modelmatrix(m::MixedModel) = m.feterm.x
117117

118118
StatsAPI.nobs(m::MixedModel) = length(m.y)
119119

120+
"""
121+
objective!(m::MixedModel, θ)
122+
objective!(m::MixedModel)
123+
124+
Equivalent to `objective(updateL!(setθ!(m, θ)))`.
125+
126+
When `m` has a single, scalar random-effects term, `θ` can be a scalar.
127+
128+
The one-argument method curries and returns a single-argument function of `θ`.
129+
130+
Note that these methods modify `m`.
131+
The calling function is responsible for restoring the optimal `θ`.
132+
"""
133+
function objective! end
134+
135+
function objective!(m::MixedModel)
136+
return Base.Fix1(objective!, m)
137+
end
138+
120139
StatsAPI.predict(m::MixedModel) = fitted(m)
121140

122141
function retbl(mat, trm)
@@ -131,13 +150,13 @@ StatsAPI.adjr2(m::MixedModel) = r2(m)
131150
function StatsAPI.r2(m::MixedModel)
132151
@error (
133152
"""There is no uniquely defined coefficient of determination for mixed models
134-
that has all the properties of the corresponding value for classical
153+
that has all the properties of the corresponding value for classical
135154
linear models. The GLMM FAQ provides more detail:
136-
155+
137156
https://bbolker.github.io/mixedmodels-misc/glmmFAQ.html#how-do-i-compute-a-coefficient-of-determination-r2-or-an-analogue-for-glmms
138157
139158
140-
Alternatively, MixedModelsExtras provides a naive implementation, but
159+
Alternatively, MixedModelsExtras provides a naive implementation, but
141160
the warnings there and in the FAQ should be taken seriously!
142161
"""
143162
)
@@ -148,7 +167,7 @@ end
148167
raneftables(m::MixedModel; uscale = false)
149168
150169
Return the conditional means of the random effects as a `NamedTuple` of Tables.jl-compliant tables.
151-
170+
152171
!!! note
153172
The API guarantee is only that the NamedTuple contains Tables.jl tables and not on the particular concrete type of each table.
154173
"""

src/nlopt.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
push!(OPTIMIZATION_BACKENDS, :nlopt)
2+
3+
const NLoptBackend = Val{:nlopt}
4+
5+
function optimize!(m::LinearMixedModel, ::NLoptBackend; progress::Bool=true, thin::Int=typemax(Int))
6+
optsum = m.optsum
7+
opt = Opt(optsum)
8+
9+
function obj(x, g)
10+
isempty(g) || throw(ArgumentError("g should be empty for this objective"))
11+
iter += 1
12+
val = if isone(iter) && x == optsum.initial
13+
optsum.finitial
14+
else
15+
try
16+
objective!(m, x)
17+
catch ex
18+
# This can happen when the optimizer drifts into an area where
19+
# there isn't enough shrinkage. Why finitial? Generally, it will
20+
# be the (near) worst case scenario value, so the optimizer won't
21+
# view it as an optimum. Using Inf messes up the quadratic
22+
# approximation in BOBYQA.
23+
ex isa PosDefException || rethrow()
24+
optsum.finitial
25+
end
26+
end
27+
progress && ProgressMeter.next!(prog; showvalues=[(:objective, val)])
28+
!isone(iter) && iszero(rem(iter, thin)) && push!(fitlog, (copy(x), val))
29+
return val
30+
end
31+
NLopt.min_objective!(opt, obj)
32+
prog = ProgressUnknown(; desc="Minimizing", showspeed=true)
33+
# start from zero for the initial call to obj before optimization
34+
iter = 0
35+
fitlog = optsum.fitlog
36+
37+
try
38+
# use explicit evaluation w/o calling opt to avoid confusing iteration count
39+
optsum.finitial = objective!(m, optsum.initial)
40+
catch ex
41+
ex isa PosDefException || rethrow()
42+
# give it one more try with a massive change in scaling
43+
@info "Initial objective evaluation failed, rescaling initial guess and trying again."
44+
@warn """Failure of the initial evaluation is often indicative of a model specification
45+
that is not well supported by the data and/or a poorly scaled model.
46+
"""
47+
optsum.initial ./=
48+
(isempty(m.sqrtwts) ? 1.0 : maximum(m.sqrtwts)^2) *
49+
maximum(response(m))
50+
optsum.finitial = objective!(m, optsum.initial)
51+
end
52+
empty!(fitlog)
53+
push!(fitlog, (copy(optsum.initial), optsum.finitial))
54+
fmin, xmin, ret = NLopt.optimize!(opt, copyto!(optsum.final, optsum.initial))
55+
ProgressMeter.finish!(prog)
56+
optsum.feval = opt.numevals
57+
optsum.returnvalue = ret
58+
_check_nlopt_return(ret)
59+
return xmin, fmin
60+
end
61+
62+
function NLopt.Opt(optsum::OptSummary)
63+
lb = optsum.lowerbd
64+
65+
opt = NLopt.Opt(optsum.optimizer, length(lb))
66+
NLopt.ftol_rel!(opt, optsum.ftol_rel) # relative criterion on objective
67+
NLopt.ftol_abs!(opt, optsum.ftol_abs) # absolute criterion on objective
68+
NLopt.xtol_rel!(opt, optsum.xtol_rel) # relative criterion on parameter values
69+
if length(optsum.xtol_abs) == length(lb) # not true for fast=false optimization in GLMM
70+
NLopt.xtol_abs!(opt, optsum.xtol_abs) # absolute criterion on parameter values
71+
end
72+
NLopt.lower_bounds!(opt, lb)
73+
NLopt.maxeval!(opt, optsum.maxfeval)
74+
NLopt.maxtime!(opt, optsum.maxtime)
75+
if isempty(optsum.initial_step)
76+
optsum.initial_step = NLopt.initial_step(opt, optsum.initial, similar(lb))
77+
else
78+
NLopt.initial_step!(opt, optsum.initial_step)
79+
end
80+
return opt
81+
end
82+
83+
84+
const _NLOPT_FAILURE_MODES = [
85+
:FAILURE,
86+
:INVALID_ARGS,
87+
:OUT_OF_MEMORY,
88+
:FORCED_STOP,
89+
:MAXEVAL_REACHED,
90+
:MAXTIME_REACHED,
91+
]
92+
93+
function _check_nlopt_return(ret, failure_modes=_NLOPT_FAILURE_MODES)
94+
ret == :ROUNDOFF_LIMITED && @warn("NLopt was roundoff limited")
95+
if ret ∈ failure_modes
96+
@warn("NLopt optimization failure: $ret")
97+
end
98+
end

0 commit comments

Comments
 (0)