Skip to content

Commit 153244f

Browse files
committed
feat: use contexts from DifferentiationInterface.jl
1 parent c2706b8 commit 153244f

File tree

3 files changed

+26
-38
lines changed

3 files changed

+26
-38
lines changed

lib/SciMLJacobianOperators/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ ADTypes = "1.8.1"
1919
Aqua = "0.8.7"
2020
ConcreteStructs = "0.2.3"
2121
ConstructionBase = "1.5"
22-
DifferentiationInterface = "0.5"
22+
DifferentiationInterface = "0.6"
2323
Enzyme = "0.12, 0.13"
2424
EnzymeCore = "0.7, 0.8"
2525
ExplicitImports = "1.9.0"

lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
module SciMLJacobianOperators
22

3-
using ADTypes: ADTypes, AutoSparse, AutoEnzyme
3+
using ADTypes: ADTypes, AutoSparse
44
using ConcreteStructs: @concrete
55
using ConstructionBase: ConstructionBase
6-
using DifferentiationInterface: DifferentiationInterface
6+
using DifferentiationInterface: DifferentiationInterface, Constant
77
using EnzymeCore: EnzymeCore
88
using FastClosures: @closure
99
using LinearAlgebra: LinearAlgebra
@@ -112,10 +112,10 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
112112
iip = SciMLBase.isinplace(prob)
113113
T = promote_type(eltype(u), eltype(fu))
114114

115-
vjp_autodiff = set_function_as_const(get_dense_ad(vjp_autodiff))
115+
vjp_autodiff = get_dense_ad(vjp_autodiff)
116116
vjp_op = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
117117

118-
jvp_autodiff = set_function_as_const(get_dense_ad(jvp_autodiff))
118+
jvp_autodiff = get_dense_ad(jvp_autodiff)
119119
jvp_op = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
120120

121121
output_cache = fu isa Number ? T(fu) : similar(fu, T)
@@ -295,23 +295,21 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
295295

296296
@assert autodiff!==nothing "`vjp_autodiff` must be provided if `f` doesn't have \
297297
analytic `vjp` or `jac`."
298-
# TODO: Once DI supports const params we can use `p`
299-
fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p)
300298
if SciMLBase.isinplace(f)
301-
@assert DI.check_twoarg(autodiff) "Backend: $(autodiff) doesn't support in-place \
302-
problems."
299+
@assert DI.check_inplace(autodiff) "Backend: $(autodiff) doesn't support in-place \
300+
problems."
303301
fu_cache = copy(fu)
304-
v_fake = copy(fu)
305-
di_extras = DI.prepare_pullback(fₚ, fu_cache, autodiff, u, v_fake)
302+
di_extras = DI.prepare_pullback(f, fu_cache, autodiff, u, (fu,), Constant(prob.p))
306303
return @closure (vJ, v, u, p) -> begin
307-
DI.pullback!(fₚ, fu_cache, reshape(vJ, size(u)), autodiff,
308-
u, reshape(v, size(fu_cache)), di_extras)
304+
DI.pullback!(f, fu_cache, (reshape(vJ, size(u)),), di_extras, autodiff,
305+
u, (reshape(v, size(fu_cache)),), Constant(p))
309306
return
310307
end
311308
else
312-
di_extras = DI.prepare_pullback(fₚ, autodiff, u, fu)
309+
di_extras = DI.prepare_pullback(f, autodiff, u, (fu,), Constant(prob.p))
313310
return @closure (v, u, p) -> begin
314-
return DI.pullback(fₚ, autodiff, u, reshape(v, size(fu)), di_extras)
311+
return only(DI.pullback(
312+
f, di_extras, autodiff, u, (reshape(v, size(fu)),), Constant(p)))
315313
end
316314
end
317315
end
@@ -342,23 +340,21 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
342340

343341
@assert autodiff!==nothing "`jvp_autodiff` must be provided if `f` doesn't have \
344342
analytic `vjp` or `jac`."
345-
# TODO: Once DI supports const params we can use `p`
346-
fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p)
347343
if SciMLBase.isinplace(f)
348-
@assert DI.check_twoarg(autodiff) "Backend: $(autodiff) doesn't support in-place \
349-
problems."
344+
@assert DI.check_inplace(autodiff) "Backend: $(autodiff) doesn't support in-place \
345+
problems."
350346
fu_cache = copy(fu)
351-
di_extras = DI.prepare_pushforward(fₚ, fu_cache, autodiff, u, u)
347+
di_extras = DI.prepare_pushforward(f, fu_cache, autodiff, u, (u,), Constant(prob.p))
352348
return @closure (Jv, v, u, p) -> begin
353-
DI.pushforward!(
354-
fₚ, fu_cache, reshape(Jv, size(fu_cache)),
355-
autodiff, u, reshape(v, size(u)), di_extras)
349+
DI.pushforward!(f, fu_cache, (reshape(Jv, size(fu_cache)),), di_extras,
350+
autodiff, u, (reshape(v, size(u)),), Constant(p))
356351
return
357352
end
358353
else
359-
di_extras = DI.prepare_pushforward(fₚ, autodiff, u, u)
354+
di_extras = DI.prepare_pushforward(f, autodiff, u, (u,), Constant(prob.p))
360355
return @closure (v, u, p) -> begin
361-
return DI.pushforward(fₚ, autodiff, u, reshape(v, size(u)), di_extras)
356+
return only(DI.pushforward(
357+
f, di_extras, autodiff, u, (reshape(v, size(u)),), Constant(p)))
362358
end
363359
end
364360
end
@@ -371,10 +367,8 @@ function prepare_scalar_op(::Val{false}, prob::AbstractNonlinearProblem,
371367

372368
@assert autodiff!==nothing "`autodiff` must be provided if `f` doesn't have \
373369
analytic `vjp` or `jvp` or `jac`."
374-
# TODO: Once DI supports const params we can use `p`
375-
fₚ = Base.Fix2(f, prob.p)
376-
di_extras = DI.prepare_derivative(fₚ, autodiff, u)
377-
return @closure (v, u, p) -> DI.derivative(fₚ, autodiff, u, di_extras) * v
370+
di_extras = DI.prepare_derivative(f, autodiff, u, Constant(prob.p))
371+
return @closure (v, u, p) -> DI.derivative(f, di_extras, autodiff, u, Constant(p)) * v
378372
end
379373

380374
get_dense_ad(::Nothing) = nothing
@@ -386,12 +380,6 @@ function get_dense_ad(ad::AutoSparse)
386380
return dense_ad
387381
end
388382

389-
# In our case we know that it is safe to mark the function as const
390-
set_function_as_const(ad) = ad
391-
function set_function_as_const(ad::AutoEnzyme{M, Nothing}) where {M}
392-
return AutoEnzyme(; ad.mode, function_annotation = EnzymeCore.Const)
393-
end
394-
395383
export JacobianOperator, VecJacOperator, JacVecOperator
396384
export StatefulJacobianOperator
397385
export StatefulJacobianNormalFormOperator

lib/SciMLJacobianOperators/test/core_tests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
AutoEnzyme(),
88
AutoEnzyme(; mode = Enzyme.Reverse),
99
AutoZygote(),
10-
AutoReverseDiff(),
10+
# AutoReverseDiff(), # FIXME: https://github.com/gdalle/DifferentiationInterface.jl/issues/503
1111
AutoTracker(),
1212
AutoFiniteDiff()
1313
]
@@ -91,7 +91,7 @@ end
9191
reverse_ADs = [
9292
AutoEnzyme(),
9393
AutoEnzyme(; mode = Enzyme.Reverse),
94-
AutoReverseDiff(),
94+
# AutoReverseDiff(), # FIXME: https://github.com/gdalle/DifferentiationInterface.jl/issues/503
9595
AutoFiniteDiff()
9696
]
9797

@@ -182,7 +182,7 @@ end
182182
AutoEnzyme(; mode = Enzyme.Reverse),
183183
AutoZygote(),
184184
AutoTracker(),
185-
AutoReverseDiff(),
185+
# AutoReverseDiff(), # FIXME: https://github.com/gdalle/DifferentiationInterface.jl/issues/503
186186
AutoFiniteDiff()
187187
]
188188

0 commit comments

Comments
 (0)