Skip to content

Commit fc2d60c

Browse files
committed
feat: SimpleNewtonRaphson
1 parent 1d7df8d commit fc2d60c

File tree

6 files changed

+157
-11
lines changed

6 files changed

+157
-11
lines changed

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ AlgebraicMultigrid = "0.5, 0.6"
3030
ArrayInterface = "6, 7"
3131
BenchmarkTools = "1"
3232
DiffEqBase = "6.136"
33-
DifferentiationInterface = "0.6"
33+
DifferentiationInterface = "0.6.1"
3434
Documenter = "1"
3535
DocumenterCitations = "1"
3636
DocumenterInterLinks = "1.0.0"

lib/NonlinearSolveBase/src/autodiff.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,21 @@
33

44
# Ordering is important here. We want to select the first one that is compatible with the
55
# problem.
6-
const ReverseADs = [
6+
const ReverseADs = (
77
ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
88
ADTypes.AutoZygote(),
99
ADTypes.AutoTracker(),
1010
ADTypes.AutoReverseDiff(; compile = true),
1111
ADTypes.AutoReverseDiff(),
1212
ADTypes.AutoFiniteDiff()
13-
]
13+
)
1414

15-
const ForwardADs = [
15+
const ForwardADs = (
1616
ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward),
1717
ADTypes.AutoPolyesterForwardDiff(),
1818
ADTypes.AutoForwardDiff(),
1919
ADTypes.AutoFiniteDiff()
20-
]
20+
)
2121

2222
# TODO: Handle Sparsity
2323

@@ -28,7 +28,8 @@ function select_forward_mode_autodiff(
2828
end
2929
if incompatible_backend_and_problem(prob, ad)
3030
adₙ = select_forward_mode_autodiff(prob, nothing; warn_check_mode)
31-
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
31+
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. This \
32+
could be because the backend package for the choosen AD isn't loaded. After \
3233
running autodiff selection detected `$(adₙ)` as a potential forward mode \
3334
backend."
3435
return adₙ
@@ -57,7 +58,8 @@ function select_reverse_mode_autodiff(
5758
end
5859
if incompatible_backend_and_problem(prob, ad)
5960
adₙ = select_reverse_mode_autodiff(prob, nothing; warn_check_mode)
60-
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
61+
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. This \
62+
could be because the backend package for the choosen AD isn't loaded. After \
6163
running autodiff selection detected `$(adₙ)` as a potential reverse mode \
6264
backend."
6365
return adₙ
@@ -77,7 +79,8 @@ end
7779
function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ad::AbstractADType)
7880
if incompatible_backend_and_problem(prob, ad)
7981
adₙ = select_jacobian_autodiff(prob, nothing)
80-
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
82+
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. This \
83+
could be because the backend package for the choosen AD isn't loaded. After \
8184
running autodiff selection detected `$(adₙ)` as a potential jacobian \
8285
backend."
8386
return adₙ

lib/SimpleNonlinearSolve/Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
1010
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
11+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1112
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1213
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1314
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
@@ -38,11 +39,13 @@ ArrayInterface = "7.16"
3839
BracketingNonlinearSolve = "1"
3940
ChainRulesCore = "1.24"
4041
CommonSolve = "0.2.4"
42+
ConcreteStructs = "0.2.3"
4143
DiffEqBase = "6.155"
42-
DifferentiationInterface = "0.5.17"
44+
DifferentiationInterface = "0.6.1"
4345
FastClosures = "0.3.2"
4446
FiniteDiff = "2.24.0"
4547
ForwardDiff = "0.10.36"
48+
InteractiveUtils = "<0.0.1, 1"
4649
LinearAlgebra = "1.10"
4750
MaybeInplace = "0.1.4"
4851
NonlinearSolveBase = "1"
@@ -51,6 +54,8 @@ Reexport = "1.2"
5154
ReverseDiff = "1.15"
5255
SciMLBase = "2.50"
5356
StaticArraysCore = "1.4.3"
57+
Test = "1.10"
58+
TestItemRunner = "1"
5459
Tracker = "0.2.35"
5560
julia = "1.10"
5661

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module SimpleNonlinearSolve
22

33
using CommonSolve: CommonSolve, solve
4+
using ConcreteStructs: @concrete
45
using FastClosures: @closure
56
using MaybeInplace: @bb
67
using PrecompileTools: @compile_workload, @setup_workload
@@ -27,6 +28,7 @@ is_extension_loaded(::Val) = false
2728
include("utils.jl")
2829

2930
include("klement.jl")
31+
include("raphson.jl")
3032

3133
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
3234
function CommonSolve.solve(prob::NonlinearProblem,
@@ -69,7 +71,10 @@ function solve_adjoint_internal end
6971
prob_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, ones(T, 3), T(2))
7072
prob_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, ones(T, 3), T(2))
7173

72-
algs = [SimpleKlement()]
74+
algs = [
75+
SimpleKlement(),
76+
SimpleNewtonRaphson()
77+
]
7378
algs_no_iip = []
7479

7580
@compile_workload begin
@@ -87,4 +92,7 @@ export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
8792

8893
export Alefeld, Bisection, Brent, Falsi, ITP, Ridder
8994

95+
export SimpleKlement
96+
export SimpleGaussNewton, SimpleNewtonRaphson
97+
9098
end
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""
2+
SimpleNewtonRaphson(autodiff)
3+
SimpleNewtonRaphson(; autodiff = nothing)
4+
5+
A low-overhead implementation of Newton-Raphson. This method is non-allocating on scalar
6+
and static array problems.
7+
8+
!!! note
9+
10+
As part of the decreased overhead, this method omits some of the higher level error
11+
catching of the other methods. Thus, to see better error messages, use one of the other
12+
methods like `NewtonRaphson`.
13+
14+
### Keyword Arguments
15+
16+
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
17+
automatic backend selection). Valid choices include jacobian backends from
18+
`DifferentiationInterface.jl`.
19+
"""
20+
@kwdef @concrete struct SimpleNewtonRaphson <: AbstractSimpleNonlinearSolveAlgorithm
21+
autodiff = nothing
22+
end
23+
24+
const SimpleGaussNewton = SimpleNewtonRaphson
25+
26+
function SciMLBase.__solve(
27+
prob::ImmutableNonlinearProblem, alg::SimpleNewtonRaphson, args...;
28+
abstol = nothing, reltol = nothing, maxiters = 1000,
29+
alias_u0 = false, termination_condition = nothing, kwargs...)
30+
x = Utils.maybe_unaliased(prob.u0, alias_u0)
31+
fx = Utils.get_fx(prob, x)
32+
fx = Utils.eval_f(prob, fx, x)
33+
34+
iszero(fx) &&
35+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
36+
37+
abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
38+
prob, abstol, reltol, fx, x, termination_condition, Val(:simple))
39+
40+
autodiff = SciMLBase.has_jac(prob.f) ? alg.autodiff :
41+
NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
42+
43+
@bb xo = similar(x)
44+
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? similar(fx) :
45+
nothing
46+
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
47+
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
48+
49+
for _ in 1:maxiters
50+
@bb copyto!(xo, x)
51+
δx = Utils.restructure(x, J \ Utils.safe_vec(fx))
52+
@bb x .-= δx
53+
54+
solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
55+
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
56+
57+
fx = Utils.eval_f(prob, fx, x)
58+
J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache)
59+
end
60+
61+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
62+
end

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module Utils
22

33
using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff
44
using ArrayInterface: ArrayInterface
5-
using DifferentiationInterface: DifferentiationInterface
5+
using DifferentiationInterface: DifferentiationInterface, Constant
66
using FastClosures: @closure
77
using LinearAlgebra: LinearAlgebra, I, diagind
88
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
@@ -132,4 +132,72 @@ function check_termination(
132132
return false, ReturnCode.Default, fx, x
133133
end
134134

135+
restructure(y, x) = ArrayInterface.restructure(y, x)
136+
restructure(::Number, x::Number) = x
137+
138+
safe_vec(x::AbstractArray) = vec(x)
139+
safe_vec(x::Number) = x
140+
141+
function prepare_jacobian(prob, autodiff, _, x::Number)
142+
if SciMLBase.has_jac(prob.f) || SciMLBase.has_vjp(prob.f) || SciMLBase.has_jvp(prob.f)
143+
return nothing
144+
end
145+
return DI.prepare_derivative(prob.f, autodiff, x, Constant(prob.p))
146+
end
147+
function prepare_jacobian(prob, autodiff, fx, x)
148+
if SciMLBase.has_jac(prob.f)
149+
return nothing
150+
end
151+
if SciMLBase.isinplace(prob.f)
152+
return DI.prepare_jacobian(prob.f, fx, autodiff, x, Constant(prob.p))
153+
else
154+
return DI.prepare_jacobian(prob.f, autodiff, x, Constant(prob.p))
155+
end
156+
end
157+
158+
function compute_jacobian!!(_, prob, autodiff, fx, x::Number, extras)
159+
if extras === nothing
160+
if SciMLBase.has_jac(prob.f)
161+
return prob.f.jac(x, prob.p)
162+
elseif SciMLBase.has_vjp(prob.f)
163+
return prob.f.vjp(one(x), x, prob.p)
164+
elseif SciMLBase.has_jvp(prob.f)
165+
return prob.f.jvp(one(x), x, prob.p)
166+
end
167+
end
168+
return DI.derivative(prob.f, extras, autodiff, x, Constant(prob.p))
169+
end
170+
function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
171+
if J === nothing
172+
if extras === nothing
173+
if SciMLBase.isinplace(prob.f)
174+
J = similar(fx, length(fx), length(x))
175+
prob.f.jac(J, x, prob.p)
176+
return J
177+
else
178+
return prob.f.jac(x, prob.p)
179+
end
180+
end
181+
if SciMLBase.isinplace(prob)
182+
return DI.jacobian(prob.f, fx, extras, autodiff, x, Constant(prob.p))
183+
else
184+
return DI.jacobian(prob.f, extras, autodiff, x, Constant(prob.p))
185+
end
186+
end
187+
if extras === nothing
188+
if SciMLBase.isinplace(prob)
189+
prob.jac(J, x, prob.p)
190+
return J
191+
else
192+
return prob.jac(x, prob.p)
193+
end
194+
end
195+
if SciMLBase.isinplace(prob)
196+
DI.jacobian!(prob.f, J, fx, extras, autodiff, x, Constant(prob.p))
197+
else
198+
DI.jacobian!(prob.f, J, extras, autodiff, x, Constant(prob.p))
199+
end
200+
return J
201+
end
202+
135203
end

0 commit comments

Comments
 (0)