Skip to content

Commit 99d3216

Browse files
committed
feat(BracketingNonlinearSolve): subtype AbstractNonlinearSolveAlgorithm
1 parent ba8f893 commit 99d3216

File tree

8 files changed

+133
-71
lines changed

8 files changed

+133
-71
lines changed

lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
module BracketingNonlinearSolve
22

33
using ConcreteStructs: @concrete
4+
using PrecompileTools: @compile_workload, @setup_workload
45
using Reexport: @reexport
56

67
using CommonSolve: CommonSolve, solve
7-
using NonlinearSolveBase: NonlinearSolveBase
8-
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, IntervalNonlinearProblem, ReturnCode
9-
10-
using PrecompileTools: @compile_workload, @setup_workload
8+
using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm
9+
using SciMLBase: SciMLBase, IntervalNonlinearProblem, ReturnCode
1110

12-
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearAlgorithm end
11+
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end
1312

1413
include("common.jl")
1514

lib/BracketingNonlinearSolve/src/alefeld.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,25 @@ algorithm 4.1 because, in certain sense, the second algorithm(4.2) is an optimal
88
"""
99
struct Alefeld <: AbstractBracketingAlgorithm end
1010

11-
function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...;
12-
maxiters = 1000, abstol = nothing, kwargs...)
11+
function CommonSolve.solve(
12+
prob::IntervalNonlinearProblem, alg::Alefeld, args...;
13+
maxiters = 1000, abstol = nothing, kwargs...
14+
)
1315
f = Base.Fix2(prob.f, prob.p)
1416
a, b = prob.tspan
1517
c = a - (b - a) / (f(b) - f(a)) * f(a)
1618

1719
fc = f(c)
1820
if a == c || b == c
1921
return SciMLBase.build_solution(
20-
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, left = a, right = b)
22+
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit, left = a, right = b
23+
)
2124
end
2225

2326
if iszero(fc)
2427
return SciMLBase.build_solution(
25-
prob, alg, c, fc; retcode = ReturnCode.Success, left = a, right = b)
28+
prob, alg, c, fc; retcode = ReturnCode.Success, left = a, right = b
29+
)
2630
end
2731

2832
a, b, d = Impl.bracket(f, a, b, c)
@@ -68,13 +72,15 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...
6872

6973
if== c ||== c
7074
return SciMLBase.build_solution(
71-
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
72-
left = ā, right = b̄)
75+
prob, alg, c, fc;
76+
retcode = ReturnCode.FloatingPointLimit, left = ā, right =
77+
)
7378
end
7479

7580
if iszero(fc)
7681
return SciMLBase.build_solution(
77-
prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right = b̄)
82+
prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right =
83+
)
7884
end
7985

8086
ā, b̄, d̄ = Impl.bracket(f, ā, b̄, c)
@@ -89,13 +95,15 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...
8995

9096
if== c ||== c
9197
return SciMLBase.build_solution(
92-
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
93-
left = ā, right = b̄)
98+
prob, alg, c, fc;
99+
retcode = ReturnCode.FloatingPointLimit, left = ā, right =
100+
)
94101
end
95102

96103
if iszero(fc)
97104
return SciMLBase.build_solution(
98-
prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right = b̄)
105+
prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right =
106+
)
99107
end
100108

101109
ā, b̄, d = Impl.bracket(f, ā, b̄, c)
@@ -110,12 +118,14 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...
110118

111119
if== c ||== c
112120
return SciMLBase.build_solution(
113-
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
114-
left = ā, right = b̄)
121+
prob, alg, c, fc;
122+
retcode = ReturnCode.FloatingPointLimit, left = ā, right =
123+
)
115124
end
116125
if iszero(fc)
117126
return SciMLBase.build_solution(
118-
prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right = b̄)
127+
prob, alg, c, fc; retcode = ReturnCode.Success, left = ā, right =
128+
)
119129
end
120130
a, b, d = Impl.bracket(f, ā, b̄, c)
121131
end
@@ -131,5 +141,6 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...
131141

132142
# Return solution when run out of max iteration
133143
return SciMLBase.build_solution(
134-
prob, alg, c, fc; retcode = ReturnCode.MaxIters, left = a, right = b)
144+
prob, alg, c, fc; retcode = ReturnCode.MaxIters, left = a, right = b
145+
)
135146
end

lib/BracketingNonlinearSolve/src/bisection.jl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ A common bisection method.
1919
exact_right::Bool = false
2020
end
2121

22-
function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Bisection,
23-
args...; maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...)
22+
function CommonSolve.solve(
23+
prob::IntervalNonlinearProblem, alg::Bisection, args...;
24+
maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
25+
)
2426
@assert !SciMLBase.isinplace(prob) "`Bisection` only supports out-of-place problems."
2527

2628
f = Base.Fix2(prob.f, prob.p)
@@ -32,20 +34,23 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Bisection,
3234

3335
if iszero(fl)
3436
return SciMLBase.build_solution(
35-
prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right)
37+
prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right
38+
)
3639
end
3740

3841
if iszero(fr)
3942
return SciMLBase.build_solution(
40-
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right)
43+
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right
44+
)
4145
end
4246

4347
if sign(fl) == sign(fr)
4448
verbose &&
4549
@warn "The interval is not an enclosing interval, opposite signs at the \
4650
boundaries are required."
4751
return SciMLBase.build_solution(
48-
prob, alg, left, fl; retcode = ReturnCode.InitialFailure, left, right)
52+
prob, alg, left, fl; retcode = ReturnCode.InitialFailure, left, right
53+
)
4954
end
5055

5156
i = 1
@@ -54,13 +59,15 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Bisection,
5459

5560
if mid == left || mid == right
5661
return SciMLBase.build_solution(
57-
prob, alg, left, fl; retcode = ReturnCode.FloatingPointLimit, left, right)
62+
prob, alg, left, fl; retcode = ReturnCode.FloatingPointLimit, left, right
63+
)
5864
end
5965

6066
fm = f(mid)
6167
if abs((right - left) / 2) < abstol
6268
return SciMLBase.build_solution(
63-
prob, alg, mid, fm; retcode = ReturnCode.Success, left, right)
69+
prob, alg, mid, fm; retcode = ReturnCode.Success, left, right
70+
)
6471
end
6572

6673
if iszero(fm)
@@ -80,10 +87,12 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Bisection,
8087
end
8188

8289
sol, i, left, right, fl, fr = Impl.bisection(
83-
left, right, fl, fr, f, abstol, maxiters - i, prob, alg)
90+
left, right, fl, fr, f, abstol, maxiters - i, prob, alg
91+
)
8492

8593
sol !== nothing && return sol
8694

8795
return SciMLBase.build_solution(
88-
prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right)
96+
prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right
97+
)
8998
end

lib/BracketingNonlinearSolve/src/brent.jl

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ Left non-allocating Brent method.
55
"""
66
struct Brent <: AbstractBracketingAlgorithm end
77

8-
function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
9-
maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...)
8+
function CommonSolve.solve(
9+
prob::IntervalNonlinearProblem, alg::Brent, args...;
10+
maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
11+
)
1012
@assert !SciMLBase.isinplace(prob) "`Brent` only supports out-of-place problems."
1113

1214
f = Base.Fix2(prob.f, prob.p)
@@ -15,24 +17,28 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
1517
ϵ = eps(convert(typeof(fl), 1))
1618

1719
abstol = NonlinearSolveBase.get_tolerance(
18-
left, abstol, promote_type(eltype(left), eltype(right)))
20+
left, abstol, promote_type(eltype(left), eltype(right))
21+
)
1922

2023
if iszero(fl)
2124
return SciMLBase.build_solution(
22-
prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right)
25+
prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right
26+
)
2327
end
2428

2529
if iszero(fr)
2630
return SciMLBase.build_solution(
27-
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right)
31+
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right
32+
)
2833
end
2934

3035
if sign(fl) == sign(fr)
3136
verbose &&
3237
@warn "The interval is not an enclosing interval, opposite signs at the \
3338
boundaries are required."
3439
return SciMLBase.build_solution(
35-
prob, alg, left, fl; retcode = ReturnCode.InitialFailure, left, right)
40+
prob, alg, left, fl; retcode = ReturnCode.InitialFailure, left, right
41+
)
3642
end
3743

3844
if abs(fl) < abs(fr)
@@ -67,8 +73,10 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
6773
# Bisection method
6874
s = (left + right) / 2
6975
if s == left || s == right
70-
return SciMLBase.build_solution(prob, alg, left, fl;
71-
retcode = ReturnCode.FloatingPointLimit, left, right)
76+
return SciMLBase.build_solution(
77+
prob, alg, left, fl;
78+
retcode = ReturnCode.FloatingPointLimit, left, right
79+
)
7280
end
7381
cond = true
7482
else
@@ -78,7 +86,8 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
7886
fs = f(s)
7987
if abs((right - left) / 2) < abstol
8088
return SciMLBase.build_solution(
81-
prob, alg, s, fs; retcode = ReturnCode.Success, left, right)
89+
prob, alg, s, fs; retcode = ReturnCode.Success, left, right
90+
)
8291
end
8392

8493
if iszero(fs)
@@ -110,10 +119,12 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
110119
end
111120

112121
sol, i, left, right, fl, fr = Impl.bisection(
113-
left, right, fl, fr, f, abstol, maxiters - i, prob, alg)
122+
left, right, fl, fr, f, abstol, maxiters - i, prob, alg
123+
)
114124

115125
sol !== nothing && return sol
116126

117127
return SciMLBase.build_solution(
118-
prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right)
128+
prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right
129+
)
119130
end

lib/BracketingNonlinearSolve/src/common.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@ function bisection(left, right, fl, fr, f::F, abstol, maxiters, prob, alg) where
1010

1111
if mid == left || mid == right
1212
sol = SciMLBase.build_solution(
13-
prob, alg, left, fl; left, right, retcode = ReturnCode.FloatingPointLimit)
13+
prob, alg, left, fl; left, right, retcode = ReturnCode.FloatingPointLimit
14+
)
1415
break
1516
end
1617

1718
fm = f(mid)
1819
if abs((right - left) / 2) < abstol
1920
sol = SciMLBase.build_solution(
20-
prob, alg, mid, fm; left, right, retcode = ReturnCode.Success)
21+
prob, alg, mid, fm; left, right, retcode = ReturnCode.Success
22+
)
2123
break
2224
end
2325

lib/BracketingNonlinearSolve/src/falsi.jl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ A non-allocating regula falsi method.
55
"""
66
struct Falsi <: AbstractBracketingAlgorithm end
77

8-
function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
9-
maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...)
8+
function CommonSolve.solve(
9+
prob::IntervalNonlinearProblem, alg::Falsi, args...;
10+
maxiters = 1000, abstol = nothing, verbose::Bool = true, kwargs...
11+
)
1012
@assert !SciMLBase.isinplace(prob) "`False` only supports out-of-place problems."
1113

1214
f = Base.Fix2(prob.f, prob.p)
@@ -19,27 +21,31 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
1921

2022
if iszero(fl)
2123
return SciMLBase.build_solution(
22-
prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right)
24+
prob, alg, left, fl; retcode = ReturnCode.ExactSolutionLeft, left, right
25+
)
2326
end
2427

2528
if iszero(fr)
2629
return SciMLBase.build_solution(
27-
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right)
30+
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight, left, right
31+
)
2832
end
2933

3034
if sign(fl) == sign(fr)
3135
verbose &&
3236
@warn "The interval is not an enclosing interval, opposite signs at the \
3337
boundaries are required."
3438
return SciMLBase.build_solution(
35-
prob, alg, left, fl; retcode = ReturnCode.InitialFailure, left, right)
39+
prob, alg, left, fl; retcode = ReturnCode.InitialFailure, left, right
40+
)
3641
end
3742

3843
i = 1
3944
while i maxiters
4045
if Impl.nextfloat_tdir(left, l, r) == right
4146
return SciMLBase.build_solution(
42-
prob, alg, left, fl; left, right, retcode = ReturnCode.FloatingPointLimit)
47+
prob, alg, left, fl; left, right, retcode = ReturnCode.FloatingPointLimit
48+
)
4349
end
4450

4551
mid = (fr * left - fl * right) / (fr - fl)
@@ -52,7 +58,8 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
5258
fm = f(mid)
5359
if abs((right - left) / 2) < abstol
5460
return SciMLBase.build_solution(
55-
prob, alg, mid, fm; left, right, retcode = ReturnCode.Success)
61+
prob, alg, mid, fm; left, right, retcode = ReturnCode.Success
62+
)
5663
end
5764

5865
if abs(fm) < abstol
@@ -70,10 +77,12 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
7077
end
7178

7279
sol, i, left, right, fl, fr = Impl.bisection(
73-
left, right, fl, fr, f, abstol, maxiters - i, prob, alg)
80+
left, right, fl, fr, f, abstol, maxiters - i, prob, alg
81+
)
7482

7583
sol !== nothing && return sol
7684

7785
return SciMLBase.build_solution(
78-
prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right)
86+
prob, alg, left, fl; retcode = ReturnCode.MaxIters, left, right
87+
)
7988
end

0 commit comments

Comments
 (0)