Skip to content

Commit c73bcf4

Browse files
committed
fix: forwarddiff support
1 parent 26c77a6 commit c73bcf4

File tree

4 files changed

+18
-16
lines changed

4 files changed

+18
-16
lines changed

lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
3737
sol = solve(newprob, alg, args...; kwargs...)
3838

3939
uu = sol.u
40-
Jₚ = nonlinearsolve_∂f_∂p(prob, prob.f, uu, p)
41-
Jᵤ = nonlinearsolve_∂f_∂u(prob, prob.f, uu, p)
40+
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, p)
41+
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, p)
4242
z = -Jᵤ \ Jₚ
4343
pp = prob.p
4444
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)
@@ -123,8 +123,8 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
123123
end
124124
end
125125

126-
Jₚ = nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p)
127-
Jᵤ = nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p)
126+
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p)
127+
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p)
128128
z = -Jᵤ \ Jₚ
129129
pp = prob.p
130130
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)
@@ -140,7 +140,7 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
140140
return sol, partials
141141
end
142142

143-
function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
143+
function NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
144144
if SciMLBase.isinplace(prob)
145145
f2 = @closure p -> begin
146146
du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p)))
@@ -159,7 +159,7 @@ function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
159159
end
160160
end
161161

162-
function nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F}
162+
function NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F}
163163
if SciMLBase.isinplace(prob)
164164
return ForwardDiff.jacobian(
165165
@closure((du, u)->f(du, u, p)), Utils.safe_similar(u), u)

lib/NonlinearSolveBase/src/public.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ function get_tolerance end
88
# Forward declarations of functions for forward mode AD
99
function nonlinearsolve_forwarddiff_solve end
1010
function nonlinearsolve_dual_solution end
11+
function nonlinearsolve_∂f_∂p end
12+
function nonlinearsolve_∂f_∂u end
1113

1214
# Nonlinear Solve Termination Conditions
1315
abstract type AbstractNonlinearTerminationMode end

src/NonlinearSolve.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ using LineSearch: LineSearch, AbstractLineSearchCache, LineSearchesJL, NoLineSea
2222
using LinearSolve: LinearSolve, QRFactorization, needs_concrete_A, AbstractFactorization,
2323
DefaultAlgorithmChoice, DefaultLinearSolver
2424
using MaybeInplace: @bb
25+
using NonlinearSolveBase: NonlinearSolveBase, nonlinearsolve_forwarddiff_solve,
26+
nonlinearsolve_dual_solution, nonlinearsolve_∂f_∂p,
27+
nonlinearsolve_∂f_∂u
2528
using Printf: @printf
2629
using Preferences: Preferences, @load_preference, @set_preferences!
2730
using RecursiveArrayTools: recursivecopy!

src/internal/forward_diff.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
# Not part of public API but helps reduce code duplication
2-
import SimpleNonlinearSolve: __nlsolve_ad, __nlsolve_dual_soln, __nlsolve_∂f_∂p,
3-
__nlsolve_∂f_∂u
4-
1+
# XXX: dispatch on `__solve` & `__init`
52
function SciMLBase.solve(
63
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
74
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
85
alg::Union{Nothing, AbstractNonlinearAlgorithm},
96
args...;
107
kwargs...) where {T, V, P, iip}
11-
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
12-
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
8+
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
9+
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
1310
return SciMLBase.build_solution(
1411
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
1512
end
@@ -53,10 +50,10 @@ function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
5350
prob = cache.prob
5451

5552
uu = sol.u
56-
f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
57-
f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)
53+
Jₚ = nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
54+
Jᵤ = nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)
5855

59-
z_arr = -f_x \ f_p
56+
z_arr = -Jᵤ \ Jₚ
6057

6158
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
6259
if cache.p isa Number
@@ -65,7 +62,7 @@ function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
6562
partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
6663
end
6764

68-
dual_soln = __nlsolve_dual_soln(sol.u, partials, cache.p)
65+
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, cache.p)
6966
return SciMLBase.build_solution(
7067
prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
7168
end

0 commit comments

Comments
 (0)