Skip to content

Commit 98ef805

Browse files
committed
feat: functional Klement
1 parent ca8e8bc commit 98ef805

File tree

7 files changed

+55
-25
lines changed

7 files changed

+55
-25
lines changed

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
1010
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
1111
NonlinearProblem, NonlinearLeastSquaresProblem, AbstractNonlinearFunction,
1212
@add_kwonly, StandardNonlinearProblem, NullParameters, NonlinearProblem,
13-
isinplace
13+
isinplace, warn_paramtype
1414
using StaticArraysCore: StaticArray
1515

1616
include("public.jl")

lib/NonlinearSolveBase/src/public.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ for name in (:Norm, :RelNorm, :AbsNorm)
5151

5252
@eval begin
5353
"""
54-
$($struct_name) <: AbstractSafeNonlinearTerminationMode
54+
$($struct_name) <: AbstractNonlinearTerminationMode
5555
5656
Terminates if $($doctring).
5757
@@ -63,7 +63,7 @@ for name in (:Norm, :RelNorm, :AbsNorm)
6363
6464
$($TERM_INTERNALNORM_DOCS).
6565
"""
66-
struct $(struct_name){F} <: AbstractSafeNonlinearTerminationMode
66+
struct $(struct_name){F} <: AbstractNonlinearTerminationMode
6767
internalnorm::F
6868

6969
function $(struct_name)(internalnorm::F) where {F}

lib/NonlinearSolveBase/src/termination_conditions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,6 @@ function init_termination_cache(::AbstractNonlinearProblem, abstol, reltol, du,
276276
T = promote_type(eltype(du), eltype(u))
277277
abstol = get_tolerance(abstol, T)
278278
reltol = get_tolerance(reltol, T)
279-
cache = init(du, u, tc; abstol, reltol)
279+
cache = SciMLBase.init(du, u, tc; abstol, reltol)
280280
return abstol, reltol, cache
281281
end

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using FiniteDiff: FiniteDiff
1818
using ForwardDiff: ForwardDiff
1919

2020
using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
21-
using NonlinearSolveBase: ImmutableNonlinearProblem, get_tolerance
21+
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, get_tolerance
2222

2323
const DI = DifferentiationInterface
2424

@@ -28,6 +28,8 @@ is_extension_loaded(::Val) = false
2828

2929
include("utils.jl")
3030

31+
include("klement.jl")
32+
3133
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
3234
function CommonSolve.solve(prob::NonlinearProblem,
3335
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
@@ -69,7 +71,7 @@ function solve_adjoint_internal end
6971
prob_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, ones(T, 3), T(2))
7072
prob_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, ones(T, 3), T(2))
7173

72-
algs = []
74+
algs = [SimpleKlement()]
7375
algs_no_iip = []
7476

7577
@compile_workload begin

lib/SimpleNonlinearSolve/src/klement.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement,
1111
alias_u0 = false, termination_condition = nothing, kwargs...)
1212
x = Utils.maybe_unaliased(prob.u0, alias_u0)
1313
T = eltype(x)
14+
fx = Utils.get_fx(prob, x)
1415

1516
abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
1617
prob, abstol, reltol, fx, x, termination_condition, Val(:simple))
@@ -32,8 +33,8 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement,
3233
fx = Utils.eval_f(prob, fx, x)
3334

3435
# Termination Checks
35-
# tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg)
36-
tc_sol !== nothing && return tc_sol
36+
solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
37+
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
3738

3839
@bb δx .*= -1
3940
@bb @. δx² = δx^2 * J^2

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@ using ArrayInterface: ArrayInterface
55
using DifferentiationInterface: DifferentiationInterface
66
using FastClosures: @closure
77
using LinearAlgebra: LinearAlgebra, I, diagind
8-
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem
8+
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
9+
AbstractNonlinearTerminationMode,
10+
AbstractSafeNonlinearTerminationMode,
11+
AbstractSafeBestNonlinearTerminationMode
912
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NonlinearLeastSquaresProblem,
10-
NonlinearProblem, NonlinearFunction
13+
NonlinearProblem, NonlinearFunction, ReturnCode
1114
using StaticArraysCore: StaticArray, SArray, SMatrix, SVector
1215

1316
const DI = DifferentiationInterface
@@ -60,7 +63,7 @@ function get_fx(prob::Union{ImmutableNonlinearProblem, NonlinearProblem}, x)
6063
end
6164
function get_fx(f::NonlinearFunction, x, p)
6265
if SciMLBase.isinplace(f)
63-
f.resid_prototype === nothing && return eltype(x).(f.resid_prototype)
66+
f.resid_prototype === nothing || return eltype(x).(f.resid_prototype)
6467
return safe_similar(x)
6568
end
6669
return f(x, p)
@@ -77,18 +80,18 @@ function fixed_parameter_function(prob::AbstractNonlinearProblem)
7780
return Base.Fix2(prob.f, prob.p)
7881
end
7982

80-
# __init_identity_jacobian(u::Number, fu, α = true) = oftype(u, α)
81-
# function __init_identity_jacobian(u, fu, α = true)
82-
# J = __similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u))
83-
# fill!(J, zero(eltype(J)))
84-
# J[diagind(J)] .= eltype(J)(α)
85-
# return J
86-
# end
87-
# function __init_identity_jacobian(u::StaticArray, fu, α = true)
88-
# S1, S2 = length(fu), length(u)
89-
# J = SMatrix{S1, S2, eltype(u)}(I * α)
90-
# return J
91-
# end
83+
function identity_jacobian(u::Number, fu::Number, α = true)
84+
return convert(promote_type(eltype(u), eltype(fu)), α)
85+
end
86+
function identity_jacobian(u, fu, α = true)
87+
J = safe_similar(u, promote_type(eltype(u), eltype(fu)))
88+
fill!(J, zero(eltype(J)))
89+
J[diagind(J)] .= eltype(J)(α)
90+
return J
91+
end
92+
function identity_jacobian(u::StaticArray, fu, α = true)
93+
return SMatrix{length(fu), length(u), eltype(u)}(I * α)
94+
end
9295

9396
identity_jacobian!!(J::Number) = one(J)
9497
function identity_jacobian!!(J::AbstractVector)
@@ -104,4 +107,28 @@ end
104107
identity_jacobian!!(::SMatrix{S1, S2, T}) where {S1, S2, T} = SMatrix{S1, S2, T}(I)
105108
identity_jacobian!!(::SVector{S1, T}) where {S1, T} = ones(SVector{S1, T})
106109

110+
# Termination Conditions
111+
function check_termination(cache, fx, x, xo, prob)
112+
return check_termination(cache, fx, x, xo, prob, cache.mode)
113+
end
114+
115+
function check_termination(cache, fx, x, xo, _, ::AbstractNonlinearTerminationMode)
116+
return cache(fx, x, xo), ReturnCode.Success, fx, x
117+
end
118+
function check_termination(cache, fx, x, xo, _, ::AbstractSafeNonlinearTerminationMode)
119+
return cache(fx, x, xo), cache.retcode, fx, x
120+
end
121+
function check_termination(cache, fx, x, xo, prob, ::AbstractSafeBestNonlinearTerminationMode)
122+
if cache(fx, x, xo)
123+
x = cache.u
124+
if SciMLBase.isinplace(prob)
125+
prob.f(fx, x, prob.p)
126+
else
127+
fx = prob.f(x, prob.p)
128+
end
129+
return true, cache.retcode, fx, x
130+
end
131+
return false, ReturnCode.Default, fx, x
132+
end
133+
107134
end

src/internal/termination.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ function update_from_termination_cache!(tc_cache, cache, u = get_u(cache))
4545
end
4646

4747
function update_from_termination_cache!(
48-
tc_cache, cache, mode::AbstractNonlinearTerminationMode, u = get_u(cache))
48+
_, cache, ::AbstractNonlinearTerminationMode, u = get_u(cache))
4949
evaluate_f!(cache, u, cache.p)
5050
end
5151

5252
function update_from_termination_cache!(
53-
tc_cache, cache, mode::AbstractSafeBestNonlinearTerminationMode, u = get_u(cache))
53+
tc_cache, cache, ::AbstractSafeBestNonlinearTerminationMode, u = get_u(cache))
5454
if isinplace(cache)
5555
copyto!(get_u(cache), tc_cache.u)
5656
else

0 commit comments

Comments
 (0)