Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 74e257b

Browse files
committed
Adding a Brent method
1 parent 6b15a20 commit 74e257b

File tree

3 files changed

+138
-3
lines changed

3 files changed

+138
-3
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ include("broyden.jl")
2323
include("klement.jl")
2424
include("trustRegion.jl")
2525
include("ridder.jl")
26+
include("brent.jl")
2627
include("ad.jl")
2728

2829
import SnoopPrecompile
@@ -44,12 +45,13 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
4445
=#
4546

4647
prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, T.((0.0, 2.0)), T(2))
47-
for alg in (Bisection, Falsi, Ridder)
48+
for alg in (Bisection, Falsi, Ridder, Brent)
4849
solve(prob_brack, alg(), abstol = T(1e-2))
4950
end
5051
end end
5152

5253
# DiffEq styled algorithms
53-
export Bisection, Broyden, Falsi, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion
54+
export Bisection, Brent, Broyden, Falsi, Klement, Ridder, SimpleNewtonRaphson,
55+
SimpleTrustRegion
5456

5557
end # module

src/brent.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""
2+
`Brent()`
3+
4+
A non-allocating Brent method
5+
6+
"""
7+
struct Brent <: AbstractBracketingAlgorithm end
8+
9+
function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
10+
maxiters = 1000,
11+
kwargs...)
12+
f = Base.Fix2(prob.f, prob.p)
13+
a, b = prob.tspan
14+
fa, fb = f(a), f(b)
15+
16+
if iszero(fa)
17+
return SciMLBase.build_solution(prob, alg, a, fa;
18+
retcode = ReturnCode.ExactSolutionLeft, left = a,
19+
right = b)
20+
end
21+
if abs(fa) < abs(fb)
22+
c = b
23+
b = a
24+
a = c
25+
tmp = fa
26+
fa = fb
27+
fb = tmp
28+
end
29+
30+
c = a
31+
d = c
32+
i = 1
33+
cond = true
34+
if !iszero(fb)
35+
while i < maxiters
36+
fc = f(c)
37+
if fa != fc && fb != fc
38+
# Inverse quadratic interpolation
39+
s = a * fb * fc / ((fa - fb) * (fa - fc)) +
40+
b * fa * fc / ((fb - fa) * (fb - fc)) +
41+
c * fa * fb / ((fc - fa) * (fc - fb))
42+
else
43+
# Secant method
44+
s = b - fb * (b - a) / (fb - fa)
45+
end
46+
if (s < min((3 * a + b) / 4, b) || s > max((3 * a + b) / 4, b)) ||
47+
(cond && abs(s - b) abs(b - c) / 2) ||
48+
(!cond && abs(s - b) abs(c - d) / 2) ||
49+
(cond && abs(b - c) eps(a)) ||
50+
(!cond && abs(c - d) eps(a))
51+
# Bisection method
52+
s = (a + b) / 2
53+
(s == a || s == b) &&
54+
return SciMLBase.build_solution(prob, alg, a, fa;
55+
retcode = ReturnCode.FloatingPointLimit,
56+
left = a, right = b)
57+
cond = true
58+
else
59+
cond = false
60+
end
61+
fs = f(s)
62+
if iszero(fs)
63+
a = b
64+
b = s
65+
break
66+
end
67+
if fa * fs < 0
68+
d = c
69+
c = b
70+
b = s
71+
fb = fs
72+
else
73+
a = s
74+
fa = fs
75+
end
76+
if abs(fa) < abs(fb)
77+
d = c
78+
c = b
79+
b = a
80+
a = c
81+
fc = fb
82+
fb = fa
83+
fa = fc
84+
end
85+
i += 1
86+
end
87+
end
88+
89+
while i < maxiters
90+
c = (a + b) / 2
91+
if (c == a || c == b)
92+
return SciMLBase.build_solution(prob, alg, a, fa;
93+
retcode = ReturnCode.FloatingPointLimit,
94+
left = a, right = b)
95+
end
96+
fc = f(c)
97+
if iszero(fc)
98+
b = c
99+
fb = fc
100+
else
101+
a = c
102+
fa = fc
103+
end
104+
i += 1
105+
end
106+
107+
return SciMLBase.build_solution(prob, alg, a, fa; retcode = ReturnCode.MaxIters,
108+
left = a, right = b)
109+
end

test/basictests.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,22 @@ for p in 1.1:0.1:100.0
121121
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
122122
end
123123

124+
# Brent
125+
g = function (p)
126+
probN = IntervalNonlinearProblem{false}(f, typeof(p).(tspan), p)
127+
sol = solve(probN, Brent())
128+
return sol.left
129+
end
130+
131+
for p in 1.1:0.1:100.0
132+
@test g(p) sqrt(p)
133+
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
134+
end
135+
124136
f, tspan = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
125137
t = (p) -> [sqrt(p[2] / p[1])]
126138
p = [0.9, 50.0]
127-
for alg in [Bisection(), Falsi(), Ridder()]
139+
for alg in [Bisection(), Falsi(), Ridder(), Brent()]
128140
global g, p
129141
g = function (p)
130142
probN = IntervalNonlinearProblem{false}(f, tspan, p)
@@ -200,6 +212,18 @@ probB = IntervalNonlinearProblem(f, tspan)
200212
sol = solve(probB, Ridder())
201213
@test sol.left sqrt(2.0)
202214

215+
# Brent
216+
sol = solve(probB, Brent())
217+
@test sol.left sqrt(2.0)
218+
tspan = (sqrt(2.0), 10.0)
219+
probB = IntervalNonlinearProblem(f, tspan)
220+
sol = solve(probB, Brent())
221+
@test sol.left sqrt(2.0)
222+
tspan = (0.0, sqrt(2.0))
223+
probB = IntervalNonlinearProblem(f, tspan)
224+
sol = solve(probB, Brent())
225+
@test sol.left sqrt(2.0)
226+
203227
# Garuntee Tests for Bisection
204228
f = function (u, p)
205229
if u < 2.0

0 commit comments

Comments
 (0)