Skip to content

Commit 0b1ce3a

Browse files
committed
fix: jacobian caching
1 parent 4a74ae0 commit 0b1ce3a

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

lib/NonlinearSolveBase/src/jacobian.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,16 @@ end
142142
## Numbers
143143
function (cache::JacobianCache{<:Number})(::Number, u, p = cache.p)
144144
cache.stats.njacs += 1
145-
SciMLBase.has_jac(cache.f) && return cache.f.jac(u, p)
146-
SciMLBase.has_vjp(cache.f) && return cache.f.vjp(one(u), u, p)
147-
SciMLBase.has_jvp(cache.f) && return cache.f.jvp(one(u), u, p)
148-
return DI.derivative(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
145+
cache.J = if SciMLBase.has_jac(cache.f)
146+
cache.f.jac(u, p)
147+
elseif SciMLBase.has_vjp(cache.f)
148+
cache.f.vjp(one(u), u, p)
149+
elseif SciMLBase.has_jvp(cache.f)
150+
cache.f.jvp(one(u), u, p)
151+
else
152+
DI.derivative(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
153+
end
154+
return cache.J
149155
end
150156

151157
## Actually Compute the Jacobian
@@ -156,12 +162,17 @@ function (cache::JacobianCache)(J::Union{AbstractMatrix, Nothing}, u, p = cache.
156162
cache.f.jac(J, u, p)
157163
else
158164
DI.jacobian!(
159-
cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p))
165+
cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p)
166+
)
160167
end
161168
return J
162169
else
163-
SciMLBase.has_jac(cache.f) && return cache.f.jac(u, p)
164-
return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
170+
if SciMLBase.has_jac(cache.f)
171+
cache.J = cache.f.jac(u, p)
172+
else
173+
cache.J = DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
174+
end
175+
return cache.J
165176
end
166177
end
167178

test/core_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
dataOut = f([1, 2, 3], nothing) + 0.1 * randn(10, 1)
55

66
resid(x, p) = f(x, p) - dataOut
7-
jac(x, p) = [dataIn .^ 2 dataIn ones(10, 1)]
7+
jac(x, p) = [1:10 .^ 2 1:10 ones(10, 1)]
88
x0 = [1, 1, 1]
99

1010
prob = NonlinearLeastSquaresProblem(resid, x0)

0 commit comments

Comments
 (0)