Skip to content

Commit 6f043bf

Browse files
authored
refactor: Move NonlinearSolvePolyAlgorithm to Base (#494)
* refactor: Move NonlinearSolvePolyAlgorithm to Base * test: Make NonlinearSolve use 1.3 Base * refactor: Remove unnecessary snippet * refactor: Don't use duplicate solve * refactor: Test Base export NonlinearSolvePolyAlgorithm
1 parent 037a07c commit 6f043bf

File tree

8 files changed

+552
-549
lines changed

8 files changed

+552
-549
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ NLSolvers = "0.5"
8989
NLsolve = "4.5"
9090
NaNMath = "1"
9191
NonlinearProblemLibrary = "0.1.2"
92-
NonlinearSolveBase = "1.2"
92+
NonlinearSolveBase = "1.3"
9393
NonlinearSolveFirstOrder = "1"
9494
NonlinearSolveQuasiNewton = "1"
9595
NonlinearSolveSpectralMethods = "1"

lib/NonlinearSolveBase/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolveBase"
22
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.2.0"
4+
version = "1.3.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ include("linear_solve.jl")
4747
include("timer_outputs.jl")
4848
include("tracing.jl")
4949
include("wrappers.jl")
50+
include("polyalg.jl")
5051

5152
include("descent/common.jl")
5253
include("descent/newton.jl")
@@ -81,4 +82,6 @@ export RelTerminationMode, AbsTerminationMode,
8182
export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogleg,
8283
GeodesicAcceleration
8384

85+
export NonlinearSolvePolyAlgorithm
86+
8487
end

lib/NonlinearSolveBase/src/polyalg.jl

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
"""
2+
NonlinearSolvePolyAlgorithm(algs; start_index::Int = 1)
3+
4+
A general way to define PolyAlgorithms for `NonlinearProblem` and
5+
`NonlinearLeastSquaresProblem`. This is a container for a tuple of algorithms that will be
6+
tried in order until one succeeds. If none succeed, then the algorithm with the lowest
7+
residual is returned.
8+
9+
### Arguments
10+
11+
- `algs`: a tuple of algorithms to try in-order! (If this is not a Tuple, then the
12+
returned algorithm is not type-stable).
13+
14+
### Keyword Arguments
15+
16+
- `start_index`: the index to start at. Defaults to `1`.
17+
18+
### Example
19+
20+
```julia
21+
using NonlinearSolve
22+
23+
alg = NonlinearSolvePolyAlgorithm((NewtonRaphson(), Broyden()))
24+
```
25+
"""
26+
@concrete struct NonlinearSolvePolyAlgorithm <: AbstractNonlinearSolveAlgorithm
27+
static_length <: Val
28+
algs <: Tuple
29+
start_index::Int
30+
end
31+
32+
function NonlinearSolvePolyAlgorithm(algs; start_index::Int = 1)
33+
@assert 0 < start_index length(algs)
34+
algs = Tuple(algs)
35+
return NonlinearSolvePolyAlgorithm(Val(length(algs)), algs, start_index)
36+
end
37+
38+
@concrete mutable struct NonlinearSolvePolyAlgorithmCache <: AbstractNonlinearSolveCache
39+
static_length <: Val
40+
prob <: AbstractNonlinearProblem
41+
42+
caches <: Tuple
43+
alg <: NonlinearSolvePolyAlgorithm
44+
45+
best::Int
46+
current::Int
47+
nsteps::Int
48+
49+
stats::NLStats
50+
total_time::Float64
51+
maxtime
52+
53+
retcode::ReturnCode.T
54+
force_stop::Bool
55+
56+
maxiters::Int
57+
internalnorm
58+
59+
u0
60+
u0_aliased
61+
alias_u0::Bool
62+
end
63+
64+
function SII.symbolic_container(cache::NonlinearSolvePolyAlgorithmCache)
65+
return cache.caches[cache.current]
66+
end
67+
SII.state_values(cache::NonlinearSolvePolyAlgorithmCache) = cache.u0
68+
69+
function Base.show(io::IO, ::MIME"text/plain", cache::NonlinearSolvePolyAlgorithmCache)
70+
println(io, "NonlinearSolvePolyAlgorithmCache with \
71+
$(Utils.unwrap_val(cache.static_length)) algorithms:")
72+
best_alg = ifelse(cache.best == -1, "nothing", cache.best)
73+
println(io, " Best Algorithm: $(best_alg)")
74+
println(
75+
io, " Current Algorithm: [$(cache.current) / $(Utils.unwrap_val(cache.static_length))]"
76+
)
77+
println(io, " nsteps: $(cache.nsteps)")
78+
println(io, " retcode: $(cache.retcode)")
79+
print(io, " Current Cache: ")
80+
NonlinearSolveBase.show_nonlinearsolve_cache(io, cache.caches[cache.current], 4)
81+
end
82+
83+
function InternalAPI.reinit!(
84+
cache::NonlinearSolvePolyAlgorithmCache, args...; p = cache.p, u0 = cache.u0
85+
)
86+
foreach(cache.caches) do cache
87+
InternalAPI.reinit!(cache, args...; p, u0)
88+
end
89+
cache.current = cache.alg.start_index
90+
InternalAPI.reinit!(cache.stats)
91+
cache.nsteps = 0
92+
cache.total_time = 0.0
93+
end
94+
95+
function SciMLBase.__init(
96+
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm, args...;
97+
stats = NLStats(0, 0, 0, 0, 0), maxtime = nothing, maxiters = 1000,
98+
internalnorm = L2_NORM, alias_u0 = false, verbose = true, kwargs...
99+
)
100+
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
101+
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
102+
immutable (checked using `ArrayInterface.ismutable`)."
103+
alias_u0 = false # If immutable don't care about aliasing
104+
end
105+
106+
u0 = prob.u0
107+
u0_aliased = alias_u0 ? copy(u0) : u0
108+
alias_u0 && (prob = SciMLBase.remake(prob; u0 = u0_aliased))
109+
110+
return NonlinearSolvePolyAlgorithmCache(
111+
alg.static_length, prob,
112+
map(alg.algs) do solver
113+
SciMLBase.__init(
114+
prob, solver, args...;
115+
stats, maxtime, internalnorm, alias_u0, verbose, kwargs...
116+
)
117+
end,
118+
alg, -1, alg.start_index, 0, stats, 0.0, maxtime,
119+
ReturnCode.Default, false, maxiters, internalnorm,
120+
u0, u0_aliased, alias_u0
121+
)
122+
end
123+
124+
@generated function InternalAPI.step!(
125+
cache::NonlinearSolvePolyAlgorithmCache{Val{N}}, args...; kwargs...
126+
) where {N}
127+
calls = []
128+
cache_syms = [gensym("cache") for i in 1:N]
129+
for i in 1:N
130+
push!(calls,
131+
quote
132+
$(cache_syms[i]) = cache.caches[$(i)]
133+
if $(i) == cache.current
134+
InternalAPI.step!($(cache_syms[i]), args...; kwargs...)
135+
$(cache_syms[i]).nsteps += 1
136+
if !NonlinearSolveBase.not_terminated($(cache_syms[i]))
137+
if SciMLBase.successful_retcode($(cache_syms[i]).retcode)
138+
cache.best = $(i)
139+
cache.force_stop = true
140+
cache.retcode = $(cache_syms[i]).retcode
141+
else
142+
cache.current = $(i + 1)
143+
end
144+
end
145+
return
146+
end
147+
end)
148+
end
149+
150+
push!(calls, quote
151+
if !(1 cache.current length(cache.caches))
152+
minfu, idx = findmin_caches(cache.prob, cache.caches)
153+
cache.best = idx
154+
cache.retcode = cache.caches[idx].retcode
155+
cache.force_stop = true
156+
return
157+
end
158+
end)
159+
160+
return Expr(:block, calls...)
161+
end
162+
163+
# Original is often determined on runtime information especially for PolyAlgorithms so it
164+
# is best to never specialize on that
165+
function build_solution_less_specialize(
166+
prob::AbstractNonlinearProblem, alg, u, resid;
167+
retcode = ReturnCode.Default, original = nothing, left = nothing,
168+
right = nothing, stats = nothing, trace = nothing, kwargs...
169+
)
170+
return SciMLBase.NonlinearSolution{
171+
eltype(eltype(u)), ndims(u), typeof(u), typeof(resid), typeof(prob),
172+
typeof(alg), Any, typeof(left), typeof(stats), typeof(trace)
173+
}(
174+
u, resid, prob, alg, retcode, original, left, right, stats, trace
175+
)
176+
end
177+
178+
function findmin_caches(prob::AbstractNonlinearProblem, caches)
179+
resids = map(caches) do cache
180+
cache === nothing && return nothing
181+
return NonlinearSolveBase.get_fu(cache)
182+
end
183+
return findmin_resids(prob, resids)
184+
end
185+
186+
@views function findmin_resids(prob::AbstractNonlinearProblem, caches)
187+
norm_fn = prob isa NonlinearLeastSquaresProblem ? Base.Fix2(norm, 2) :
188+
Base.Fix2(norm, Inf)
189+
idx = findfirst(Base.Fix2(!==, nothing), caches)
190+
# This is an internal function so we assume that inputs are consistent and there is
191+
# atleast one non-`nothing` value
192+
fx_idx = norm_fn(caches[idx])
193+
idx == length(caches) && return fx_idx, idx
194+
fmin = @closure xᵢ -> begin
195+
xᵢ === nothing && return oftype(fx_idx, Inf)
196+
fx = norm_fn(xᵢ)
197+
return ifelse(isnan(fx), oftype(fx, Inf), fx)
198+
end
199+
x_min, x_min_idx = findmin(fmin, caches[(idx + 1):length(caches)])
200+
x_min < fx_idx && return x_min, x_min_idx + idx
201+
return fx_idx, idx
202+
end

0 commit comments

Comments
 (0)