|
| 1 | +""" |
| 2 | + NonlinearSolvePolyAlgorithm(algs; start_index::Int = 1) |
| 3 | +
|
| 4 | +A general way to define PolyAlgorithms for `NonlinearProblem` and |
| 5 | +`NonlinearLeastSquaresProblem`. This is a container for a tuple of algorithms that will be |
| 6 | +tried in order until one succeeds. If none succeed, then the algorithm with the lowest |
| 7 | +residual is returned. |
| 8 | +
|
| 9 | +### Arguments |
| 10 | +
|
| 11 | + - `algs`: a tuple of algorithms to try in-order! (If this is not a Tuple, then the |
| 12 | + returned algorithm is not type-stable). |
| 13 | +
|
| 14 | +### Keyword Arguments |
| 15 | +
|
| 16 | + - `start_index`: the index to start at. Defaults to `1`. |
| 17 | +
|
| 18 | +### Example |
| 19 | +
|
| 20 | +```julia |
| 21 | +using NonlinearSolve |
| 22 | +
|
| 23 | +alg = NonlinearSolvePolyAlgorithm((NewtonRaphson(), Broyden())) |
| 24 | +``` |
| 25 | +""" |
| 26 | +@concrete struct NonlinearSolvePolyAlgorithm <: AbstractNonlinearSolveAlgorithm |
| 27 | + static_length <: Val |
| 28 | + algs <: Tuple |
| 29 | + start_index::Int |
| 30 | +end |
| 31 | + |
| 32 | +function NonlinearSolvePolyAlgorithm(algs; start_index::Int = 1) |
| 33 | + @assert 0 < start_index ≤ length(algs) |
| 34 | + algs = Tuple(algs) |
| 35 | + return NonlinearSolvePolyAlgorithm(Val(length(algs)), algs, start_index) |
| 36 | +end |
| 37 | + |
| 38 | +@concrete mutable struct NonlinearSolvePolyAlgorithmCache <: AbstractNonlinearSolveCache |
| 39 | + static_length <: Val |
| 40 | + prob <: AbstractNonlinearProblem |
| 41 | + |
| 42 | + caches <: Tuple |
| 43 | + alg <: NonlinearSolvePolyAlgorithm |
| 44 | + |
| 45 | + best::Int |
| 46 | + current::Int |
| 47 | + nsteps::Int |
| 48 | + |
| 49 | + stats::NLStats |
| 50 | + total_time::Float64 |
| 51 | + maxtime |
| 52 | + |
| 53 | + retcode::ReturnCode.T |
| 54 | + force_stop::Bool |
| 55 | + |
| 56 | + maxiters::Int |
| 57 | + internalnorm |
| 58 | + |
| 59 | + u0 |
| 60 | + u0_aliased |
| 61 | + alias_u0::Bool |
| 62 | +end |
| 63 | + |
| 64 | +function SII.symbolic_container(cache::NonlinearSolvePolyAlgorithmCache) |
| 65 | + return cache.caches[cache.current] |
| 66 | +end |
| 67 | +SII.state_values(cache::NonlinearSolvePolyAlgorithmCache) = cache.u0 |
| 68 | + |
| 69 | +function Base.show(io::IO, ::MIME"text/plain", cache::NonlinearSolvePolyAlgorithmCache) |
| 70 | + println(io, "NonlinearSolvePolyAlgorithmCache with \ |
| 71 | + $(Utils.unwrap_val(cache.static_length)) algorithms:") |
| 72 | + best_alg = ifelse(cache.best == -1, "nothing", cache.best) |
| 73 | + println(io, " Best Algorithm: $(best_alg)") |
| 74 | + println( |
| 75 | + io, " Current Algorithm: [$(cache.current) / $(Utils.unwrap_val(cache.static_length))]" |
| 76 | + ) |
| 77 | + println(io, " nsteps: $(cache.nsteps)") |
| 78 | + println(io, " retcode: $(cache.retcode)") |
| 79 | + print(io, " Current Cache: ") |
| 80 | + NonlinearSolveBase.show_nonlinearsolve_cache(io, cache.caches[cache.current], 4) |
| 81 | +end |
| 82 | + |
| 83 | +function InternalAPI.reinit!( |
| 84 | + cache::NonlinearSolvePolyAlgorithmCache, args...; p = cache.p, u0 = cache.u0 |
| 85 | +) |
| 86 | + foreach(cache.caches) do cache |
| 87 | + InternalAPI.reinit!(cache, args...; p, u0) |
| 88 | + end |
| 89 | + cache.current = cache.alg.start_index |
| 90 | + InternalAPI.reinit!(cache.stats) |
| 91 | + cache.nsteps = 0 |
| 92 | + cache.total_time = 0.0 |
| 93 | +end |
| 94 | + |
| 95 | +function SciMLBase.__init( |
| 96 | + prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm, args...; |
| 97 | + stats = NLStats(0, 0, 0, 0, 0), maxtime = nothing, maxiters = 1000, |
| 98 | + internalnorm = L2_NORM, alias_u0 = false, verbose = true, kwargs... |
| 99 | +) |
| 100 | + if alias_u0 && !ArrayInterface.ismutable(prob.u0) |
| 101 | + verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \ |
| 102 | + immutable (checked using `ArrayInterface.ismutable`)." |
| 103 | + alias_u0 = false # If immutable don't care about aliasing |
| 104 | + end |
| 105 | + |
| 106 | + u0 = prob.u0 |
| 107 | + u0_aliased = alias_u0 ? copy(u0) : u0 |
| 108 | + alias_u0 && (prob = SciMLBase.remake(prob; u0 = u0_aliased)) |
| 109 | + |
| 110 | + return NonlinearSolvePolyAlgorithmCache( |
| 111 | + alg.static_length, prob, |
| 112 | + map(alg.algs) do solver |
| 113 | + SciMLBase.__init( |
| 114 | + prob, solver, args...; |
| 115 | + stats, maxtime, internalnorm, alias_u0, verbose, kwargs... |
| 116 | + ) |
| 117 | + end, |
| 118 | + alg, -1, alg.start_index, 0, stats, 0.0, maxtime, |
| 119 | + ReturnCode.Default, false, maxiters, internalnorm, |
| 120 | + u0, u0_aliased, alias_u0 |
| 121 | + ) |
| 122 | +end |
| 123 | + |
| 124 | +@generated function InternalAPI.step!( |
| 125 | + cache::NonlinearSolvePolyAlgorithmCache{Val{N}}, args...; kwargs... |
| 126 | +) where {N} |
| 127 | + calls = [] |
| 128 | + cache_syms = [gensym("cache") for i in 1:N] |
| 129 | + for i in 1:N |
| 130 | + push!(calls, |
| 131 | + quote |
| 132 | + $(cache_syms[i]) = cache.caches[$(i)] |
| 133 | + if $(i) == cache.current |
| 134 | + InternalAPI.step!($(cache_syms[i]), args...; kwargs...) |
| 135 | + $(cache_syms[i]).nsteps += 1 |
| 136 | + if !NonlinearSolveBase.not_terminated($(cache_syms[i])) |
| 137 | + if SciMLBase.successful_retcode($(cache_syms[i]).retcode) |
| 138 | + cache.best = $(i) |
| 139 | + cache.force_stop = true |
| 140 | + cache.retcode = $(cache_syms[i]).retcode |
| 141 | + else |
| 142 | + cache.current = $(i + 1) |
| 143 | + end |
| 144 | + end |
| 145 | + return |
| 146 | + end |
| 147 | + end) |
| 148 | + end |
| 149 | + |
| 150 | + push!(calls, quote |
| 151 | + if !(1 ≤ cache.current ≤ length(cache.caches)) |
| 152 | + minfu, idx = findmin_caches(cache.prob, cache.caches) |
| 153 | + cache.best = idx |
| 154 | + cache.retcode = cache.caches[idx].retcode |
| 155 | + cache.force_stop = true |
| 156 | + return |
| 157 | + end |
| 158 | + end) |
| 159 | + |
| 160 | + return Expr(:block, calls...) |
| 161 | +end |
| 162 | + |
| 163 | +# Original is often determined on runtime information especially for PolyAlgorithms so it |
| 164 | +# is best to never specialize on that |
| 165 | +function build_solution_less_specialize( |
| 166 | + prob::AbstractNonlinearProblem, alg, u, resid; |
| 167 | + retcode = ReturnCode.Default, original = nothing, left = nothing, |
| 168 | + right = nothing, stats = nothing, trace = nothing, kwargs... |
| 169 | +) |
| 170 | + return SciMLBase.NonlinearSolution{ |
| 171 | + eltype(eltype(u)), ndims(u), typeof(u), typeof(resid), typeof(prob), |
| 172 | + typeof(alg), Any, typeof(left), typeof(stats), typeof(trace) |
| 173 | + }( |
| 174 | + u, resid, prob, alg, retcode, original, left, right, stats, trace |
| 175 | + ) |
| 176 | +end |
| 177 | + |
| 178 | +function findmin_caches(prob::AbstractNonlinearProblem, caches) |
| 179 | + resids = map(caches) do cache |
| 180 | + cache === nothing && return nothing |
| 181 | + return NonlinearSolveBase.get_fu(cache) |
| 182 | + end |
| 183 | + return findmin_resids(prob, resids) |
| 184 | +end |
| 185 | + |
| 186 | +@views function findmin_resids(prob::AbstractNonlinearProblem, caches) |
| 187 | + norm_fn = prob isa NonlinearLeastSquaresProblem ? Base.Fix2(norm, 2) : |
| 188 | + Base.Fix2(norm, Inf) |
| 189 | + idx = findfirst(Base.Fix2(!==, nothing), caches) |
| 190 | + # This is an internal function so we assume that inputs are consistent and there is |
| 191 | + # atleast one non-`nothing` value |
| 192 | + fx_idx = norm_fn(caches[idx]) |
| 193 | + idx == length(caches) && return fx_idx, idx |
| 194 | + fmin = @closure xᵢ -> begin |
| 195 | + xᵢ === nothing && return oftype(fx_idx, Inf) |
| 196 | + fx = norm_fn(xᵢ) |
| 197 | + return ifelse(isnan(fx), oftype(fx, Inf), fx) |
| 198 | + end |
| 199 | + x_min, x_min_idx = findmin(fmin, caches[(idx + 1):length(caches)]) |
| 200 | + x_min < fx_idx && return x_min, x_min_idx + idx |
| 201 | + return fx_idx, idx |
| 202 | +end |
0 commit comments