@@ -2,14 +2,102 @@ module MTKChainRulesCoreExt
2
2
3
3
import ModelingToolkit as MTK
4
4
import ChainRulesCore
5
- import ChainRulesCore: NoTangent
5
+ import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
6
6
7
7
function ChainRulesCore. rrule (:: Type{MTK.MTKParameters} , tunables, args... )
8
8
function mtp_pullback (dt)
9
+ dt = unthunk (dt)
9
10
(NoTangent (), dt. tunable[1 : length (tunables)],
10
11
ntuple (_ -> NoTangent (), length (args))... )
11
12
end
12
13
MTK. MTKParameters (tunables, args... ), mtp_pullback
13
14
end
14
15
16
+ notangent_or_else (:: NoTangent , _, x) = x
17
+ notangent_or_else (_, x, _) = x
18
+ notangent_fallback (x, y) = notangent_or_else (x, x, y)
19
+ reduce_to_notangent (x, y) = notangent_or_else (x, y, x)
20
+
21
+ function subset_idxs (idxs, portion, template)
22
+ ntuple (Val (length (template))) do subi
23
+ [Base. tail (idx. idx) for idx in idxs if idx. portion == portion && idx. idx[1 ] == subi]
24
+ end
25
+ end
26
+
27
+ selected_tangents (:: NoTangent , _) = ()
28
+ selected_tangents (:: ZeroTangent , _) = ZeroTangent ()
29
+ function selected_tangents (
30
+ tangents:: AbstractArray{T} , idxs:: Vector{Tuple{Int}} ) where {T <: Number }
31
+ selected_tangents (tangents, map (only, idxs))
32
+ end
33
+ function selected_tangents (tangents:: AbstractArray{T} , idxs... ) where {T <: Number }
34
+ newtangents = copy (tangents)
35
+ view (newtangents, idxs... ) .= zero (T)
36
+ newtangents
37
+ end
38
+ function selected_tangents (
39
+ tangents:: AbstractVector{T} , idxs) where {S <: Number , T <: AbstractArray{S} }
40
+ newtangents = copy (tangents)
41
+ for i in idxs
42
+ j, k... = i
43
+ if k == ()
44
+ newtangents[j] = zero (newtangents[j])
45
+ else
46
+ newtangents[j] = selected_tangents (newtangents[j], k... )
47
+ end
48
+ end
49
+ newtangents
50
+ end
51
+ function selected_tangents (tangents:: AbstractVector{T} , idxs) where {T <: AbstractArray }
52
+ newtangents = similar (tangents, Union{T, NoTangent})
53
+ copyto! (newtangents, tangents)
54
+ for i in idxs
55
+ j, k... = i
56
+ if k == ()
57
+ newtangents[j] = NoTangent ()
58
+ else
59
+ newtangents[j] = selected_tangents (newtangents[j], k... )
60
+ end
61
+ end
62
+ newtangents
63
+ end
64
+ function selected_tangents (
65
+ tangents:: Union{Tangent{<:Tuple}, Tangent{T, <:Tuple}} , idxs) where {T}
66
+ ntuple (Val (length (tangents))) do i
67
+ selected_tangents (tangents[i], idxs[i])
68
+ end
69
+ end
70
+
71
+ function ChainRulesCore. rrule (
72
+ :: typeof (MTK. remake_buffer), indp, oldbuf:: MTK.MTKParameters , idxs, vals)
73
+ if idxs isa AbstractSet
74
+ idxs = collect (idxs)
75
+ end
76
+ idxs = map (idxs) do i
77
+ i isa MTK. ParameterIndex ? i : MTK. parameter_index (indp, i)
78
+ end
79
+ newbuf = MTK. remake_buffer (indp, oldbuf, idxs, vals)
80
+ tunable_idxs = reduce (
81
+ vcat, (idx. idx for idx in idxs if idx. portion isa MTK. SciMLStructures. Tunable))
82
+ disc_idxs = subset_idxs (idxs, MTK. SciMLStructures. Discrete (), oldbuf. discrete)
83
+ const_idxs = subset_idxs (idxs, MTK. SciMLStructures. Constants (), oldbuf. constant)
84
+ nn_idxs = subset_idxs (idxs, MTK. NONNUMERIC_PORTION, oldbuf. nonnumeric)
85
+
86
+ function remake_buffer_pullback (buf′)
87
+ buf′ = unthunk (buf′)
88
+ f′ = NoTangent ()
89
+ indp′ = NoTangent ()
90
+
91
+ tunable = selected_tangents (buf′. tunable, tunable_idxs)
92
+ discrete = selected_tangents (buf′. discrete, disc_idxs)
93
+ constant = selected_tangents (buf′. constant, const_idxs)
94
+ nonnumeric = selected_tangents (buf′. nonnumeric, nn_idxs)
95
+ oldbuf′ = Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
96
+ idxs′ = NoTangent ()
97
+ vals′ = map (i -> MTK. _ducktyped_parameter_values (buf′, i), idxs)
98
+ return f′, indp′, oldbuf′, idxs′, vals′
99
+ end
100
+ newbuf, remake_buffer_pullback
101
+ end
102
+
15
103
end
0 commit comments