1
1
module SciMLJacobianOperators
2
2
3
- using ADTypes: ADTypes, AutoSparse, AutoEnzyme
3
+ using ADTypes: ADTypes, AutoSparse
4
4
using ConcreteStructs: @concrete
5
5
using ConstructionBase: ConstructionBase
6
- using DifferentiationInterface: DifferentiationInterface
6
+ using DifferentiationInterface: DifferentiationInterface, Constant
7
7
using EnzymeCore: EnzymeCore
8
8
using FastClosures: @closure
9
9
using LinearAlgebra: LinearAlgebra
@@ -112,10 +112,10 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
112
112
iip = SciMLBase. isinplace (prob)
113
113
T = promote_type (eltype (u), eltype (fu))
114
114
115
- vjp_autodiff = set_function_as_const ( get_dense_ad (vjp_autodiff) )
115
+ vjp_autodiff = get_dense_ad (vjp_autodiff)
116
116
vjp_op = prepare_vjp (skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
117
117
118
- jvp_autodiff = set_function_as_const ( get_dense_ad (jvp_autodiff) )
118
+ jvp_autodiff = get_dense_ad (jvp_autodiff)
119
119
jvp_op = prepare_jvp (skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
120
120
121
121
output_cache = fu isa Number ? T (fu) : similar (fu, T)
@@ -295,23 +295,21 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
295
295
296
296
@assert autodiff!= = nothing " `vjp_autodiff` must be provided if `f` doesn't have \
297
297
analytic `vjp` or `jac`."
298
- # TODO : Once DI supports const params we can use `p`
299
- fₚ = SciMLBase. JacobianWrapper {SciMLBase.isinplace(f)} (f, prob. p)
300
298
if SciMLBase. isinplace (f)
301
- @assert DI. check_twoarg (autodiff) " Backend: $(autodiff) doesn't support in-place \
302
- problems."
299
+ @assert DI. check_inplace (autodiff) " Backend: $(autodiff) doesn't support in-place \
300
+ problems."
303
301
fu_cache = copy (fu)
304
- v_fake = copy (fu)
305
- di_extras = DI. prepare_pullback (fₚ, fu_cache, autodiff, u, v_fake)
302
+ di_extras = DI. prepare_pullback (f, fu_cache, autodiff, u, (fu,), Constant (prob. p))
306
303
return @closure (vJ, v, u, p) -> begin
307
- DI. pullback! (fₚ , fu_cache, reshape (vJ, size (u)), autodiff,
308
- u, reshape (v, size (fu_cache)), di_extras )
304
+ DI. pullback! (f , fu_cache, ( reshape (vJ, size (u)),), di_extras , autodiff,
305
+ u, ( reshape (v, size (fu_cache)),), Constant (p) )
309
306
return
310
307
end
311
308
else
312
- di_extras = DI. prepare_pullback (fₚ , autodiff, u, fu )
309
+ di_extras = DI. prepare_pullback (f , autodiff, u, (fu,), Constant (prob . p) )
313
310
return @closure (v, u, p) -> begin
314
- return DI. pullback (fₚ, autodiff, u, reshape (v, size (fu)), di_extras)
311
+ return only (DI. pullback (
312
+ f, di_extras, autodiff, u, (reshape (v, size (fu)),), Constant (p)))
315
313
end
316
314
end
317
315
end
@@ -342,23 +340,21 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
342
340
343
341
@assert autodiff!= = nothing " `jvp_autodiff` must be provided if `f` doesn't have \
344
342
analytic `vjp` or `jac`."
345
- # TODO : Once DI supports const params we can use `p`
346
- fₚ = SciMLBase. JacobianWrapper {SciMLBase.isinplace(f)} (f, prob. p)
347
343
if SciMLBase. isinplace (f)
348
- @assert DI. check_twoarg (autodiff) " Backend: $(autodiff) doesn't support in-place \
349
- problems."
344
+ @assert DI. check_inplace (autodiff) " Backend: $(autodiff) doesn't support in-place \
345
+ problems."
350
346
fu_cache = copy (fu)
351
- di_extras = DI. prepare_pushforward (fₚ , fu_cache, autodiff, u, u )
347
+ di_extras = DI. prepare_pushforward (f , fu_cache, autodiff, u, (u,), Constant (prob . p) )
352
348
return @closure (Jv, v, u, p) -> begin
353
- DI. pushforward! (
354
- fₚ, fu_cache, reshape (Jv, size (fu_cache)),
355
- autodiff, u, reshape (v, size (u)), di_extras)
349
+ DI. pushforward! (f, fu_cache, (reshape (Jv, size (fu_cache)),), di_extras,
350
+ autodiff, u, (reshape (v, size (u)),), Constant (p))
356
351
return
357
352
end
358
353
else
359
- di_extras = DI. prepare_pushforward (fₚ , autodiff, u, u )
354
+ di_extras = DI. prepare_pushforward (f , autodiff, u, (u,), Constant (prob . p) )
360
355
return @closure (v, u, p) -> begin
361
- return DI. pushforward (fₚ, autodiff, u, reshape (v, size (u)), di_extras)
356
+ return only (DI. pushforward (
357
+ f, di_extras, autodiff, u, (reshape (v, size (u)),), Constant (p)))
362
358
end
363
359
end
364
360
end
@@ -371,10 +367,8 @@ function prepare_scalar_op(::Val{false}, prob::AbstractNonlinearProblem,
371
367
372
368
@assert autodiff!= = nothing " `autodiff` must be provided if `f` doesn't have \
373
369
analytic `vjp` or `jvp` or `jac`."
374
- # TODO : Once DI supports const params we can use `p`
375
- fₚ = Base. Fix2 (f, prob. p)
376
- di_extras = DI. prepare_derivative (fₚ, autodiff, u)
377
- return @closure (v, u, p) -> DI. derivative (fₚ, autodiff, u, di_extras) * v
370
+ di_extras = DI. prepare_derivative (f, autodiff, u, Constant (prob. p))
371
+ return @closure (v, u, p) -> DI. derivative (f, di_extras, autodiff, u, Constant (p)) * v
378
372
end
379
373
380
374
get_dense_ad (:: Nothing ) = nothing
@@ -386,12 +380,6 @@ function get_dense_ad(ad::AutoSparse)
386
380
return dense_ad
387
381
end
388
382
389
- # In our case we know that it is safe to mark the function as const
390
- set_function_as_const (ad) = ad
391
- function set_function_as_const (ad:: AutoEnzyme{M, Nothing} ) where {M}
392
- return AutoEnzyme (; ad. mode, function_annotation = EnzymeCore. Const)
393
- end
394
-
395
383
export JacobianOperator, VecJacOperator, JacVecOperator
396
384
export StatefulJacobianOperator
397
385
export StatefulJacobianNormalFormOperator
0 commit comments