Skip to content

Commit 6b0002b

Browse files
authored
fix: hessian (#489)
* fix: hessian through nonlinear solvers * feat: extend gradient support for cached nlls
1 parent 748fb09 commit 6b0002b

File tree

8 files changed

+177
-100
lines changed

8 files changed

+177
-100
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ NLSolvers = "0.5"
8989
NLsolve = "4.5"
9090
NaNMath = "1"
9191
NonlinearProblemLibrary = "0.1.2"
92-
NonlinearSolveBase = "1"
92+
NonlinearSolveBase = "1.2"
9393
NonlinearSolveFirstOrder = "1"
9494
NonlinearSolveQuasiNewton = "1"
9595
NonlinearSolveSpectralMethods = "1"

lib/NonlinearSolveBase/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolveBase"
22
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.1.0"
4+
version = "1.2.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl

Lines changed: 9 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ using CommonSolve: solve
66
using DifferentiationInterface: DifferentiationInterface
77
using FastClosures: @closure
88
using ForwardDiff: ForwardDiff, Dual
9-
using LinearAlgebra: mul!
109
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
1110
NonlinearProblem, NonlinearLeastSquaresProblem, remake
1211

@@ -20,11 +19,14 @@ function NonlinearSolveBase.additional_incompatible_backend_check(
2019
end
2120

2221
Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
23-
Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))
22+
Utils.value(x::Dual) = ForwardDiff.value(x)
2423
Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)
2524

2625
function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
27-
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem},
26+
prob::Union{
27+
IntervalNonlinearProblem, NonlinearProblem,
28+
ImmutableNonlinearProblem, NonlinearLeastSquaresProblem
29+
},
2830
alg, args...; kwargs...
2931
)
3032
p = Utils.value(prob.p)
@@ -35,98 +37,14 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
3537
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
3638
end
3739

38-
sol = solve(newprob, alg, args...; kwargs...)
39-
40-
uu = sol.u
41-
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, p)
42-
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, p)
43-
z = -Jᵤ \ Jₚ
44-
pp = prob.p
45-
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)
46-
47-
if uu isa Number
48-
partials = sum(sumfun, zip(z, pp))
49-
elseif p isa Number
50-
partials = sumfun((z, pp))
51-
else
52-
partials = sum(sumfun, zip(eachcol(z), pp))
53-
end
54-
55-
return sol, partials
56-
end
57-
58-
function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
59-
prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...
60-
)
61-
p = Utils.value(prob.p)
62-
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
6340
sol = solve(newprob, alg, args...; kwargs...)
6441
uu = sol.u
6542

66-
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
67-
# nested autodiff as the last resort
68-
if SciMLBase.has_vjp(prob.f)
69-
if SciMLBase.isinplace(prob)
70-
vjp_fn = @closure (du, u, p) -> begin
71-
resid = Utils.safe_similar(du, length(sol.resid))
72-
prob.f(resid, u, p)
73-
prob.f.vjp(du, resid, u, p)
74-
du .*= 2
75-
return nothing
76-
end
77-
else
78-
vjp_fn = @closure (u, p) -> begin
79-
resid = prob.f(u, p)
80-
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
81-
end
82-
end
83-
elseif SciMLBase.has_jac(prob.f)
84-
if SciMLBase.isinplace(prob)
85-
vjp_fn = @closure (du, u, p) -> begin
86-
J = Utils.safe_similar(du, length(sol.resid), length(u))
87-
prob.f.jac(J, u, p)
88-
resid = Utils.safe_similar(du, length(sol.resid))
89-
prob.f(resid, u, p)
90-
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
91-
return nothing
92-
end
93-
else
94-
vjp_fn = @closure (u, p) -> begin
95-
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
96-
end
97-
end
98-
else
99-
# For small problems, nesting ForwardDiff is actually quite fast
100-
autodiff = length(uu) + length(sol.resid) 50 ?
101-
NonlinearSolveBase.select_reverse_mode_autodiff(prob, nothing) :
102-
AutoForwardDiff()
103-
104-
if SciMLBase.isinplace(prob)
105-
vjp_fn = @closure (du, u, p) -> begin
106-
resid = Utils.safe_similar(du, length(sol.resid))
107-
prob.f(resid, u, p)
108-
# Using `Constant` lead to dual ordering issues
109-
ff = @closure (du, u) -> prob.f(du, u, p)
110-
resid2 = copy(resid)
111-
DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,))
112-
@. du *= 2
113-
return nothing
114-
end
115-
else
116-
vjp_fn = @closure (u, p) -> begin
117-
v = prob.f(u, p)
118-
# Using `Constant` lead to dual ordering issues
119-
ff = Base.Fix2(prob.f, p)
120-
res = only(DI.pullback(ff, autodiff, u, (v,)))
121-
ArrayInterface.can_setindex(res) || return 2 .* res
122-
@. res *= 2
123-
return res
124-
end
125-
end
126-
end
43+
fn = prob isa NonlinearLeastSquaresProblem ?
44+
NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f
12745

128-
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p)
129-
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p)
46+
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, p)
47+
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, p)
13048
z = -Jᵤ \ Jₚ
13149
pp = prob.p
13250
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using ConcreteStructs: @concrete
55
using FastClosures: @closure
66
using Preferences: @load_preference, @set_preferences!
77

8-
using ADTypes: ADTypes, AbstractADType, AutoSparse, NoSparsityDetector,
8+
using ADTypes: ADTypes, AbstractADType, AutoSparse, AutoForwardDiff, NoSparsityDetector,
99
KnownJacobianSparsityDetector
1010
using Adapt: WrappedArray
1111
using ArrayInterface: ArrayInterface
@@ -25,7 +25,7 @@ using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator
2525
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
2626
using SymbolicIndexingInterface: SymbolicIndexingInterface
2727

28-
using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind
28+
using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind, mul!
2929
using Markdown: @doc_str
3030
using Printf: @printf
3131

lib/NonlinearSolveBase/src/autodiff.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,65 @@ end
128128
is_finite_differences_backend(ad::AbstractADType) = false
129129
is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true
130130
is_finite_differences_backend(::ADTypes.AutoFiniteDifferences) = true
131+
132+
function nlls_generate_vjp_function(prob::NonlinearLeastSquaresProblem, sol, uu)
133+
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
134+
# nested autodiff as the last resort
135+
if SciMLBase.has_vjp(prob.f)
136+
if SciMLBase.isinplace(prob)
137+
return @closure (du, u, p) -> begin
138+
resid = Utils.safe_similar(du, length(sol.resid))
139+
prob.f.vjp(resid, u, p)
140+
prob.f.vjp(du, resid, u, p)
141+
du .*= 2
142+
return nothing
143+
end
144+
else
145+
return @closure (u, p) -> begin
146+
resid = prob.f(u, p)
147+
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
148+
end
149+
end
150+
elseif SciMLBase.has_jac(prob.f)
151+
if SciMLBase.isinplace(prob)
152+
return @closure (du, u, p) -> begin
153+
J = Utils.safe_similar(du, length(sol.resid), length(u))
154+
prob.f.jac(J, u, p)
155+
resid = Utils.safe_similar(du, length(sol.resid))
156+
prob.f(resid, u, p)
157+
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
158+
return nothing
159+
end
160+
else
161+
return @closure (u, p) -> begin
162+
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
163+
end
164+
end
165+
else
166+
# For small problems, nesting ForwardDiff is actually quite fast
167+
autodiff = length(uu) + length(sol.resid) 50 ?
168+
select_reverse_mode_autodiff(prob, nothing) : AutoForwardDiff()
169+
170+
if SciMLBase.isinplace(prob)
171+
return @closure (du, u, p) -> begin
172+
resid = Utils.safe_similar(du, length(sol.resid))
173+
prob.f(resid, u, p)
174+
# Using `Constant` lead to dual ordering issues
175+
ff = @closure (du, u) -> prob.f(du, u, p)
176+
resid2 = copy(resid)
177+
DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,))
178+
@. du *= 2
179+
return nothing
180+
end
181+
else
182+
return @closure (u, p) -> begin
183+
v = prob.f(u, p)
184+
# Using `Constant` lead to dual ordering issues
185+
res = only(DI.pullback(Base.Fix2(prob.f, p), autodiff, u, (v,)))
186+
ArrayInterface.can_setindex(res) || return 2 .* res
187+
@. res *= 2
188+
return res
189+
end
190+
end
191+
end
192+
end

lib/NonlinearSolveBase/src/public.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ function nonlinearsolve_forwarddiff_solve end
1010
function nonlinearsolve_dual_solution end
1111
function nonlinearsolve_∂f_∂p end
1212
function nonlinearsolve_∂f_∂u end
13+
function nlls_generate_vjp_function end
1314

1415
# Nonlinear Solve Termination Conditions
1516
abstract type AbstractNonlinearTerminationMode end

src/forward_diff.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@ function InternalAPI.reinit!(
4848
end
4949

5050
for algType in ALL_SOLVER_TYPES
51-
# XXX: Extend to DualNonlinearLeastSquaresProblem
5251
@eval function SciMLBase.__init(
53-
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...
52+
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
5453
)
5554
p = nodual_value(prob.p)
5655
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
@@ -64,10 +63,13 @@ end
6463
function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
6564
sol = solve!(cache.cache)
6665
prob = cache.prob
67-
6866
uu = sol.u
69-
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
70-
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)
67+
68+
fn = prob isa NonlinearLeastSquaresProblem ?
69+
NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f
70+
71+
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p)
72+
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p)
7173

7274
z_arr = -Jᵤ \ Jₚ
7375

test/forward_ad_tests.jl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,97 @@ end
124124
end
125125
end
126126
end
127+
128+
@testitem "NLLS Hessian SciML/NonlinearSolve.jl#445" tags=[:core] begin
129+
using ForwardDiff, FiniteDiff
130+
131+
function objfn(F, init, params)
132+
th1, th2 = init
133+
px, py, l1, l2 = params
134+
F[1] = l1 * cos(th1) + l2 * cos(th1 + th2) - px
135+
F[2] = l1 * sin(th1) + l2 * sin(th1 + th2) - py
136+
return F
137+
end
138+
139+
function solve_nlprob(pxpy)
140+
px, py = pxpy
141+
theta1 = pi / 4
142+
theta2 = pi / 4
143+
initial_guess = [theta1; theta2]
144+
l1 = 60
145+
l2 = 60
146+
p = [px; py; l1; l2]
147+
prob = NonlinearLeastSquaresProblem(
148+
NonlinearFunction(objfn, resid_prototype = zeros(2)),
149+
initial_guess, p
150+
)
151+
resu = solve(
152+
prob,
153+
reltol = 1e-12, abstol = 1e-12
154+
)
155+
th1, th2 = resu.u
156+
cable1_base = [-90; 0; 0]
157+
cable2_base = [-150; 0; 0]
158+
cable3_base = [150; 0; 0]
159+
cable1_top = [l1 * cos(th1) / 2; l1 * sin(th1) / 2; 0]
160+
cable23_top = [l1 * cos(th1) + l2 * cos(th1 + th2) / 2;
161+
l1 * sin(th1) + l2 * sin(th1 + th2) / 2; 0]
162+
c1_length = sqrt((cable1_top[1] - cable1_base[1])^2 +
163+
(cable1_top[2] - cable1_base[2])^2)
164+
c2_length = sqrt((cable23_top[1] - cable2_base[1])^2 +
165+
(cable23_top[2] - cable2_base[2])^2)
166+
c3_length = sqrt((cable23_top[1] - cable3_base[1])^2 +
167+
(cable23_top[2] - cable3_base[2])^2)
168+
return c1_length + c2_length + c3_length
169+
end
170+
171+
grad1 = ForwardDiff.gradient(solve_nlprob, [34.0, 87.0])
172+
grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob, [34.0, 87.0])
173+
174+
@test grad1grad2 atol=1e-3
175+
176+
hess1 = ForwardDiff.hessian(solve_nlprob, [34.0, 87.0])
177+
hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob, [34.0, 87.0])
178+
179+
@test hess1hess2 atol=1e-3
180+
181+
function solve_nlprob_with_cache(pxpy)
182+
px, py = pxpy
183+
theta1 = pi / 4
184+
theta2 = pi / 4
185+
initial_guess = [theta1; theta2]
186+
l1 = 60
187+
l2 = 60
188+
p = [px; py; l1; l2]
189+
prob = NonlinearLeastSquaresProblem(
190+
NonlinearFunction(objfn, resid_prototype = zeros(2)),
191+
initial_guess, p
192+
)
193+
cache = init(prob; reltol = 1e-12, abstol = 1e-12)
194+
resu = solve!(cache)
195+
th1, th2 = resu.u
196+
cable1_base = [-90; 0; 0]
197+
cable2_base = [-150; 0; 0]
198+
cable3_base = [150; 0; 0]
199+
cable1_top = [l1 * cos(th1) / 2; l1 * sin(th1) / 2; 0]
200+
cable23_top = [l1 * cos(th1) + l2 * cos(th1 + th2) / 2;
201+
l1 * sin(th1) + l2 * sin(th1 + th2) / 2; 0]
202+
c1_length = sqrt((cable1_top[1] - cable1_base[1])^2 +
203+
(cable1_top[2] - cable1_base[2])^2)
204+
c2_length = sqrt((cable23_top[1] - cable2_base[1])^2 +
205+
(cable23_top[2] - cable2_base[2])^2)
206+
c3_length = sqrt((cable23_top[1] - cable3_base[1])^2 +
207+
(cable23_top[2] - cable3_base[2])^2)
208+
return c1_length + c2_length + c3_length
209+
end
210+
211+
grad1 = ForwardDiff.gradient(solve_nlprob_with_cache, [34.0, 87.0])
212+
grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob_with_cache, [34.0, 87.0])
213+
214+
@test grad1grad2 atol=1e-3
215+
216+
hess1 = ForwardDiff.hessian(solve_nlprob_with_cache, [34.0, 87.0])
217+
hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob_with_cache, [34.0, 87.0])
218+
219+
@test hess1hess2 atol=1e-3
220+
end

0 commit comments

Comments
 (0)