Skip to content

Commit c2438c7

Browse files
committed
fix: hessian through nonlinear solvers
1 parent 748fb09 commit c2438c7

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ function NonlinearSolveBase.additional_incompatible_backend_check(
2020
end
2121

2222
Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
23-
Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))
23+
Utils.value(x::Dual) = ForwardDiff.value(x)
2424
Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)
2525

2626
function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(

test/forward_ad_tests.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,57 @@ end
124124
end
125125
end
126126
end
127+
128+
@testitem "NLLS Hessian SciML/NonlinearSolve.jl#445" tags=[:core] begin
129+
using ForwardDiff, FiniteDiff
130+
131+
function objfn(F, init, params)
132+
th1, th2 = init
133+
px, py, l1, l2 = params
134+
F[1] = l1 * cos(th1) + l2 * cos(th1 + th2) - px
135+
F[2] = l1 * sin(th1) + l2 * sin(th1 + th2) - py
136+
return F
137+
end
138+
139+
function solve_nlprob(pxpy)
140+
px, py = pxpy
141+
theta1 = pi / 4
142+
theta2 = pi / 4
143+
initial_guess = [theta1; theta2]
144+
l1 = 60
145+
l2 = 60
146+
p = [px; py; l1; l2]
147+
prob = NonlinearLeastSquaresProblem(
148+
NonlinearFunction(objfn, resid_prototype = zeros(2)),
149+
initial_guess, p
150+
)
151+
resu = solve(
152+
prob,
153+
reltol = 1e-12, abstol = 1e-12
154+
)
155+
th1, th2 = resu.u
156+
cable1_base = [-90; 0; 0]
157+
cable2_base = [-150; 0; 0]
158+
cable3_base = [150; 0; 0]
159+
cable1_top = [l1 * cos(th1) / 2; l1 * sin(th1) / 2; 0]
160+
cable23_top = [l1 * cos(th1) + l2 * cos(th1 + th2) / 2;
161+
l1 * sin(th1) + l2 * sin(th1 + th2) / 2; 0]
162+
c1_length = sqrt((cable1_top[1] - cable1_base[1])^2 +
163+
(cable1_top[2] - cable1_base[2])^2)
164+
c2_length = sqrt((cable23_top[1] - cable2_base[1])^2 +
165+
(cable23_top[2] - cable2_base[2])^2)
166+
c3_length = sqrt((cable23_top[1] - cable3_base[1])^2 +
167+
(cable23_top[2] - cable3_base[2])^2)
168+
return c1_length + c2_length + c3_length
169+
end
170+
171+
grad1 = ForwardDiff.gradient(solve_nlprob, [34.0, 87.0])
172+
grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob, [34.0, 87.0])
173+
174+
@test grad1 grad2 atol = 1e-3
175+
176+
hess1 = ForwardDiff.hessian(solve_nlprob, [34.0, 87.0])
177+
hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob, [34.0, 87.0])
178+
179+
@test hess1 hess2 atol = 1e-3
180+
end

0 commit comments

Comments
 (0)