Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit b127990

Browse files
Merge pull request #38 from CCsimon123/main
Adding a Brent method
2 parents a220ae4 + 1068d63 commit b127990

File tree

3 files changed

+143
-3
lines changed

3 files changed

+143
-3
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ include("broyden.jl")
2323
include("klement.jl")
2424
include("trustRegion.jl")
2525
include("ridder.jl")
26+
include("brent.jl")
2627
include("ad.jl")
2728

2829
import SnoopPrecompile
@@ -44,12 +45,13 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
4445
=#
4546

4647
prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, T.((0.0, 2.0)), T(2))
47-
for alg in (Bisection, Falsi, Ridder)
48+
for alg in (Bisection, Falsi, Ridder, Brent)
4849
solve(prob_brack, alg(), abstol = T(1e-2))
4950
end
5051
end end
5152

5253
# DiffEq styled algorithms
53-
export Bisection, Broyden, Falsi, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion
54+
export Bisection, Brent, Broyden, Falsi, Klement, Ridder, SimpleNewtonRaphson,
55+
SimpleTrustRegion
5456

5557
end # module

src/brent.jl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""
2+
`Brent()`
3+
4+
A non-allocating Brent method
5+
6+
"""
7+
struct Brent <: AbstractBracketingAlgorithm end
8+
9+
function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
10+
maxiters = 1000,
11+
kwargs...)
12+
f = Base.Fix2(prob.f, prob.p)
13+
a, b = prob.tspan
14+
fa, fb = f(a), f(b)
15+
ϵ = eps(convert(typeof(fa), 1.0))
16+
17+
if iszero(fa)
18+
return SciMLBase.build_solution(prob, alg, a, fa;
19+
retcode = ReturnCode.ExactSolutionLeft, left = a,
20+
right = b)
21+
end
22+
if abs(fa) < abs(fb)
23+
c = b
24+
b = a
25+
a = c
26+
tmp = fa
27+
fa = fb
28+
fb = tmp
29+
end
30+
31+
c = a
32+
d = c
33+
i = 1
34+
cond = true
35+
if !iszero(fb)
36+
while i < maxiters
37+
fc = f(c)
38+
if fa != fc && fb != fc
39+
# Inverse quadratic interpolation
40+
s = a * fb * fc / ((fa - fb) * (fa - fc)) +
41+
b * fa * fc / ((fb - fa) * (fb - fc)) +
42+
c * fa * fb / ((fc - fa) * (fc - fb))
43+
else
44+
# Secant method
45+
s = b - fb * (b - a) / (fb - fa)
46+
end
47+
if (s < min((3 * a + b) / 4, b) || s > max((3 * a + b) / 4, b)) ||
48+
(cond && abs(s - b) abs(b - c) / 2) ||
49+
(!cond && abs(s - b) abs(c - d) / 2) ||
50+
(cond && abs(b - c) ϵ) ||
51+
(!cond && abs(c - d) ϵ)
52+
# Bisection method
53+
s = (a + b) / 2
54+
(s == a || s == b) &&
55+
return SciMLBase.build_solution(prob, alg, a, fa;
56+
retcode = ReturnCode.FloatingPointLimit,
57+
left = a, right = b)
58+
cond = true
59+
else
60+
cond = false
61+
end
62+
fs = f(s)
63+
if iszero(fs)
64+
if b < a
65+
a = b
66+
fa = fb
67+
end
68+
b = s
69+
fb = fs
70+
break
71+
end
72+
if fa * fs < 0
73+
d = c
74+
c = b
75+
b = s
76+
fb = fs
77+
else
78+
a = s
79+
fa = fs
80+
end
81+
if abs(fa) < abs(fb)
82+
d = c
83+
c = b
84+
b = a
85+
a = c
86+
fc = fb
87+
fb = fa
88+
fa = fc
89+
end
90+
i += 1
91+
end
92+
end
93+
94+
while i < maxiters
95+
c = (a + b) / 2
96+
if (c == a || c == b)
97+
return SciMLBase.build_solution(prob, alg, a, fa;
98+
retcode = ReturnCode.FloatingPointLimit,
99+
left = a, right = b)
100+
end
101+
fc = f(c)
102+
if iszero(fc)
103+
b = c
104+
fb = fc
105+
else
106+
a = c
107+
fa = fc
108+
end
109+
i += 1
110+
end
111+
112+
return SciMLBase.build_solution(prob, alg, a, fa; retcode = ReturnCode.MaxIters,
113+
left = a, right = b)
114+
end

test/basictests.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,22 @@ for p in 1.1:0.1:100.0
121121
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
122122
end
123123

124+
# Brent
125+
g = function (p)
126+
probN = IntervalNonlinearProblem{false}(f, typeof(p).(tspan), p)
127+
sol = solve(probN, Brent())
128+
return sol.left
129+
end
130+
131+
for p in 1.1:0.1:100.0
132+
@test g(p) sqrt(p)
133+
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
134+
end
135+
124136
f, tspan = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
125137
t = (p) -> [sqrt(p[2] / p[1])]
126138
p = [0.9, 50.0]
127-
for alg in [Bisection(), Falsi(), Ridder()]
139+
for alg in [Bisection(), Falsi(), Ridder(), Brent()]
128140
global g, p
129141
g = function (p)
130142
probN = IntervalNonlinearProblem{false}(f, tspan, p)
@@ -200,6 +212,18 @@ probB = IntervalNonlinearProblem(f, tspan)
200212
sol = solve(probB, Ridder())
201213
@test sol.left sqrt(2.0)
202214

215+
# Brent
216+
sol = solve(probB, Brent())
217+
@test sol.left sqrt(2.0)
218+
tspan = (sqrt(2.0), 10.0)
219+
probB = IntervalNonlinearProblem(f, tspan)
220+
sol = solve(probB, Brent())
221+
@test sol.left sqrt(2.0)
222+
tspan = (0.0, sqrt(2.0))
223+
probB = IntervalNonlinearProblem(f, tspan)
224+
sol = solve(probB, Brent())
225+
@test sol.left sqrt(2.0)
226+
203227
# Garuntee Tests for Bisection
204228
f = function (u, p)
205229
if u < 2.0

0 commit comments

Comments
 (0)