Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit a220ae4

Browse files
Merge pull request #37 from CCsimon123/main
Implementation of Ridder
2 parents 178e186 + 6b15a20 commit a220ae4

File tree

3 files changed

+109
-3
lines changed

3 files changed

+109
-3
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ include("raphson.jl")
2222
include("broyden.jl")
2323
include("klement.jl")
2424
include("trustRegion.jl")
25+
include("ridder.jl")
2526
include("ad.jl")
2627

2728
import SnoopPrecompile
@@ -43,12 +44,12 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
4344
=#
4445

4546
prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, T.((0.0, 2.0)), T(2))
46-
for alg in (Bisection, Falsi)
47+
for alg in (Bisection, Falsi, Ridder)
4748
solve(prob_brack, alg(), abstol = T(1e-2))
4849
end
4950
end end
5051

5152
# DiffEq styled algorithms
52-
export Bisection, Broyden, Falsi, Klement, SimpleNewtonRaphson, SimpleTrustRegion
53+
export Bisection, Broyden, Falsi, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion
5354

5455
end # module

src/ridder.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
`Ridder()`
3+
4+
A non-allocating ridder method
5+
6+
"""
7+
struct Ridder <: AbstractBracketingAlgorithm end
8+
9+
function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
10+
maxiters = 1000,
11+
kwargs...)
12+
f = Base.Fix2(prob.f, prob.p)
13+
left, right = prob.tspan
14+
fl, fr = f(left), f(right)
15+
16+
if iszero(fl)
17+
return SciMLBase.build_solution(prob, alg, left, fl;
18+
retcode = ReturnCode.ExactSolutionLeft, left = left,
19+
right = right)
20+
end
21+
22+
xo = oftype(left, Inf)
23+
i = 1
24+
if !iszero(fr)
25+
while i < maxiters
26+
mid = (left + right) / 2
27+
(mid == left || mid == right) &&
28+
return SciMLBase.build_solution(prob, alg, left, fl;
29+
retcode = ReturnCode.FloatingPointLimit,
30+
left = left, right = right)
31+
fm = f(mid)
32+
s = sqrt(fm^2 - fl * fr)
33+
iszero(s) &&
34+
return SciMLBase.build_solution(prob, alg, left, fl;
35+
retcode = ReturnCode.Failure,
36+
left = left, right = right)
37+
x = mid + (mid - left) * sign(fl - fr) * fm / s
38+
fx = f(x)
39+
xo = x
40+
if iszero(fx)
41+
right = x
42+
fr = fx
43+
break
44+
end
45+
if sign(fx) != sign(fm)
46+
left = mid
47+
fl = fm
48+
right = x
49+
fr = fx
50+
elseif sign(fx) != sign(fl)
51+
right = x
52+
fr = fx
53+
else
54+
@assert sign(fx) != sign(fr)
55+
left = x
56+
fl = fx
57+
end
58+
i += 1
59+
end
60+
end
61+
62+
while i < maxiters
63+
mid = (left + right) / 2
64+
(mid == left || mid == right) &&
65+
return SciMLBase.build_solution(prob, alg, left, fl;
66+
retcode = ReturnCode.FloatingPointLimit,
67+
left = left, right = right)
68+
fm = f(mid)
69+
if iszero(fm)
70+
right = mid
71+
fr = fm
72+
else
73+
left = mid
74+
fl = fm
75+
end
76+
i += 1
77+
end
78+
79+
return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
80+
left = left, right = right)
81+
end

test/basictests.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,22 @@ for p in 1.1:0.1:100.0
109109
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
110110
end
111111

112+
# Ridder
113+
g = function (p)
114+
probN = IntervalNonlinearProblem{false}(f, typeof(p).(tspan), p)
115+
sol = solve(probN, Ridder())
116+
return sol.left
117+
end
118+
119+
for p in 1.1:0.1:100.0
120+
@test g(p) sqrt(p)
121+
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
122+
end
123+
112124
f, tspan = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
113125
t = (p) -> [sqrt(p[2] / p[1])]
114126
p = [0.9, 50.0]
115-
for alg in [Bisection(), Falsi()]
127+
for alg in [Bisection(), Falsi(), Ridder()]
116128
global g, p
117129
g = function (p)
118130
probN = IntervalNonlinearProblem{false}(f, tspan, p)
@@ -176,6 +188,18 @@ sol = solve(probB, Falsi())
176188
sol = solve(probB, Bisection())
177189
@test sol.left sqrt(2.0)
178190

191+
# Ridder
192+
sol = solve(probB, Ridder())
193+
@test sol.left sqrt(2.0)
194+
tspan = (sqrt(2.0), 10.0)
195+
probB = IntervalNonlinearProblem(f, tspan)
196+
sol = solve(probB, Ridder())
197+
@test sol.left sqrt(2.0)
198+
tspan = (0.0, sqrt(2.0))
199+
probB = IntervalNonlinearProblem(f, tspan)
200+
sol = solve(probB, Ridder())
201+
@test sol.left sqrt(2.0)
202+
179203
# Garuntee Tests for Bisection
180204
f = function (u, p)
181205
if u < 2.0

0 commit comments

Comments
 (0)