Skip to content

Commit 9206a3f

Browse files
committed
Fixed Point Solvers working
1 parent 67aa8a8 commit 9206a3f

File tree

9 files changed

+161
-161
lines changed

9 files changed

+161
-161
lines changed
Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,21 @@
11
module NonlinearSolveFixedPointAccelerationExt
22

3-
using NonlinearSolve, FixedPointAcceleration, DiffEqBase, SciMLBase
3+
using NonlinearSolve, FixedPointAcceleration, SciMLBase
44

55
function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...;
66
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
77
show_trace::Val{PrintReports} = Val(false), termination_condition = nothing,
88
kwargs...) where {PrintReports}
9-
@assert (termination_condition ===
10-
nothing)||(termination_condition isa AbsNormTerminationMode) "FixedPointAccelerationJL does not support termination conditions!"
11-
12-
f, u0 = NonlinearSolve.__construct_f(prob; alias_u0, make_fixed_point = Val(true),
13-
force_oop = Val(true))
9+
NonlinearSolve.__test_termination_condition(termination_condition,
10+
:FixedPointAccelerationJL)
1411

12+
f, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0,
13+
make_fixed_point = Val(true), force_oop = Val(true))
1514
tol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
1615

17-
sol = fixed_point(f, u0; Algorithm = alg.algorithm,
18-
ConvergenceMetricThreshold = tol, MaxIter = maxiters, MaxM = alg.m,
19-
ExtrapolationPeriod = alg.extrapolation_period, Dampening = alg.dampening,
20-
PrintReports, ReplaceInvalids = alg.replace_invalids,
16+
sol = fixed_point(f, u0; Algorithm = alg.algorithm, MaxIter = maxiters, MaxM = alg.m,
17+
ConvergenceMetricThreshold = tol, ExtrapolationPeriod = alg.extrapolation_period,
18+
Dampening = alg.dampening, PrintReports, ReplaceInvalids = alg.replace_invalids,
2119
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true)
2220

2321
if sol.FixedPoint_ === missing
@@ -31,10 +29,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL
3129
resid = NonlinearSolve.evaluate_f(prob, res)
3230
converged = maximum(abs, resid) tol
3331
end
34-
return SciMLBase.build_solution(prob, alg, res, resid;
32+
33+
return SciMLBase.build_solution(prob, alg, res, resid; original = sol,
3534
retcode = converged ? ReturnCode.Success : ReturnCode.Failure,
36-
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_),
37-
original = sol)
35+
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_))
3836
end
3937

4038
end

ext/NonlinearSolveMINPACKExt.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ module NonlinearSolveMINPACKExt
22

33
using MINPACK, NonlinearSolve, SciMLBase
44
import FastClosures: @closure
5-
import SciMLBase: AbstractNonlinearProblem
65

7-
function SciMLBase.__solve(prob::AbstractNonlinearProblem, alg::CMINPACK, args...;
8-
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
9-
show_trace::Val{ShT} = Val(false), store_trace::Val{StT} = Val(false),
10-
termination_condition = nothing, kwargs...) where {uType, iip, ShT, StT}
6+
function SciMLBase.__solve(prob::Union{NonlinearLeastSquaresProblem,
7+
NonlinearProblem}, alg::CMINPACK, args...; abstol = nothing, maxiters = 1000,
8+
alias_u0::Bool = false, show_trace::Val{ShT} = Val(false),
9+
store_trace::Val{StT} = Val(false), termination_condition = nothing,
10+
kwargs...) where {ShT, StT}
1111
NonlinearSolve.__test_termination_condition(termination_condition, :CMINPACK)
1212

1313
_f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)

ext/NonlinearSolveSIAMFANLEquationsExt.jl

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
4040
atol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, T)
4141
rtol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, T)
4242

43-
f, u, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0,
44-
can_handle_oop = Val(true), can_handle_scalar = Val(true),
45-
make_fixed_point = Val(method == :anderson))
46-
47-
if u isa Number
43+
if prob.u0 isa Number
44+
f = @closure u -> prob.f(u, prob.p)
4845
if method == :newton
49-
sol = nsolsc(f, u; maxit = maxiters, atol, rtol, printerr = ShT)
46+
sol = nsolsc(f, prob.u0; maxit = maxiters, atol, rtol, printerr = ShT)
5047
elseif method == :pseudotransient
51-
sol = ptcsolsc(f, u; delta0 = delta, maxit = maxiters, atol, rtol,
48+
sol = ptcsolsc(f, prob.u0; delta0 = delta, maxit = maxiters, atol, rtol,
5249
printerr = ShT)
5350
elseif method == :secant
5451
sol = secant(f, u; maxit = maxiters, atol, rtol, printerr = ShT)
@@ -57,33 +54,30 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
5754
atol, rtol, beta)
5855
end
5956
else
57+
f, u, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
6058
N = length(u)
6159
FS = __zeros_like(u, N)
6260

6361
# Jacobian Free Newton Krylov
6462
if linsolve !== nothing
6563
# Allocate ahead for Krylov basis
6664
JVS = linsolve == :gmres ? __zeros_like(u, N, 3) : __zeros_like(u, N)
67-
# `linsolve` as a Symbol to keep unified interface with other EXTs,
68-
# SIAMFANLEquations directly use String to choose between different linear
69-
# solvers
7065
linsolve_alg = String(linsolve)
71-
7266
if method == :newton
73-
sol = nsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol,
67+
sol = nsoli(f, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol,
7468
rtol, printerr = ShT)
7569
elseif method == :pseudotransient
76-
sol = ptcsoli(f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters,
70+
sol = ptcsoli(f, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters,
7771
atol, rtol, printerr = ShT)
7872
end
7973
else
8074
if prob.f.jac === nothing && alg.autodiff === missing
8175
FPS = __zeros_like(u, N, N)
8276
if method == :newton
83-
sol = nsol(f!, u, FS, FPS; sham = 1, atol, rtol, maxit = maxiters,
77+
sol = nsol(f, u, FS, FPS; sham = 1, atol, rtol, maxit = maxiters,
8478
printerr = ShT)
8579
elseif method == :pseudotransient
86-
sol = ptcsol(f!, u, FS, FPS; atol, rtol, maxit = maxiters,
80+
sol = ptcsol(f, u, FS, FPS; atol, rtol, maxit = maxiters,
8781
delta0 = delta, printerr = ShT)
8882
elseif method == :anderson
8983
sol = aasol(f!, u, m, zeros(T, N, 2 * m + 4), atol, rtol,
@@ -92,14 +86,14 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
9286
else
9387
FPS = prob.f.jac_prototype !== nothing ? zero(prob.f.jac_prototype) :
9488
__zeros_like(u, N, N)
95-
jac! = NonlinearSolve.__construct_extension_jac(prob, alg, u, resid;
89+
jac = NonlinearSolve.__construct_extension_jac(prob, alg, u, resid;
9690
alg.autodiff)
97-
AJ! = @closure (J, u, x) -> jac!(J, x)
91+
AJ! = @closure (J, u, x) -> jac(J, x)
9892
if method == :newton
99-
sol = nsol(f!, u, FS, FPS, AJ!; sham = 1, atol, rtol, maxit = maxiters,
93+
sol = nsol(f, u, FS, FPS, AJ!; sham = 1, atol, rtol, maxit = maxiters,
10094
printerr = ShT)
10195
elseif method == :pseudotransient
102-
sol = ptcsol(f!, u, FS, FPS, AJ!; atol, rtol, maxit = maxiters,
96+
sol = ptcsol(f, u, FS, FPS, AJ!; atol, rtol, maxit = maxiters,
10397
delta0 = delta, printerr = ShT)
10498
end
10599
end

ext/NonlinearSolveSpeedMappingExt.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
module NonlinearSolveSpeedMappingExt
22

3-
using NonlinearSolve, SpeedMapping, DiffEqBase, SciMLBase
3+
using NonlinearSolve, SciMLBase, SpeedMapping
44

55
function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...;
6-
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
6+
abstol = nothing, maxiters = 1000, alias_u0::Bool = false, maxtime = nothing,
77
store_trace::Val{store_info} = Val(false), termination_condition = nothing,
88
kwargs...) where {store_info}
9-
@assert (termination_condition ===
10-
nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!"
9+
NonlinearSolve.__test_termination_condition(termination_condition, :SpeedMappingJL)
1110

12-
m!, u0 = NonlinearSolve.__construct_f(prob; alias_u0, make_fixed_point = Val(true),
13-
can_handle_arbitrary_dims = Val(true))
11+
m!, u, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0,
12+
make_fixed_point = Val(true))
13+
tol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
1414

15-
tol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
15+
time_limit = ifelse(maxtime === nothing, alg.time_limit, maxtime)
1616

17-
sol = speedmapping(u0; m!, tol, Lp = Inf, maps_limit = maxiters, alg.orders,
18-
alg.check_obj, store_info, alg.σ_min, alg.stabilize)
17+
sol = speedmapping(u; m!, tol, Lp = Inf, maps_limit = maxiters, alg.orders,
18+
alg.check_obj, store_info, alg.σ_min, alg.stabilize, time_limit)
1919
res = prob.u0 isa Number ? first(sol.minimizer) : sol.minimizer
2020
resid = NonlinearSolve.evaluate_f(prob, res)
2121

22-
return SciMLBase.build_solution(prob, alg, res, resid;
22+
return SciMLBase.build_solution(prob, alg, res, resid; original = sol,
2323
retcode = sol.converged ? ReturnCode.Success : ReturnCode.Failure,
24-
stats = SciMLBase.NLStats(sol.maps, 0, 0, 0, sol.maps), original = sol)
24+
stats = SciMLBase.NLStats(sol.maps, 0, 0, 0, sol.maps))
2525
end
2626

2727
end

0 commit comments

Comments
 (0)