@@ -82,7 +82,7 @@ function calculate_hessian end
82
82
83
83
"""
84
84
```julia
85
- generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = parameters (sys),
85
+ generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = full_parameters (sys),
86
86
expression = Val{true}; kwargs...)
87
87
```
88
88
@@ -93,7 +93,7 @@ function generate_tgrad end
93
93
94
94
"""
95
95
```julia
96
- generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters (sys),
96
+ generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters (sys),
97
97
expression = Val{true}; kwargs...)
98
98
```
99
99
@@ -104,7 +104,7 @@ function generate_gradient end
104
104
105
105
"""
106
106
```julia
107
- generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters (sys),
107
+ generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters (sys),
108
108
expression = Val{true}; sparse = false, kwargs...)
109
109
```
110
110
@@ -115,7 +115,7 @@ function generate_jacobian end
115
115
116
116
"""
117
117
```julia
118
- generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters (sys),
118
+ generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters (sys),
119
119
expression = Val{true}; sparse = false, kwargs...)
120
120
```
121
121
@@ -126,7 +126,7 @@ function generate_factorized_W end
126
126
127
127
"""
128
128
```julia
129
- generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters (sys),
129
+ generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters (sys),
130
130
expression = Val{true}; sparse = false, kwargs...)
131
131
```
132
132
@@ -137,14 +137,158 @@ function generate_hessian end
137
137
138
138
"""
139
139
```julia
140
- generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters (sys),
140
+ generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters (sys),
141
141
expression = Val{true}; kwargs...)
142
142
```
143
143
144
144
Generate a function to evaluate the system's equations.
145
145
"""
146
146
function generate_function end
147
147
148
+ function generate_custom_function (sys:: AbstractSystem , exprs, dvs = unknowns (sys),
149
+ ps = parameters (sys); wrap_code = nothing , kwargs... )
150
+ p = reorder_parameters (sys, ps)
151
+ isscalar = ! (exprs isa AbstractArray)
152
+ if wrap_code === nothing
153
+ wrap_code = isscalar ? identity : (identity, identity)
154
+ end
155
+ pre, sol_states = get_substitutions_and_solved_unknowns (sys)
156
+
157
+ if is_time_dependent (sys)
158
+ return build_function (exprs,
159
+ dvs,
160
+ p... ,
161
+ get_iv (sys);
162
+ kwargs... ,
163
+ postprocess_fbody = pre,
164
+ states = sol_states,
165
+ wrap_code = wrap_code .∘ wrap_mtkparameters (sys, isscalar) .∘
166
+ wrap_array_vars (sys, exprs; dvs)
167
+ )
168
+ else
169
+ return build_function (exprs,
170
+ dvs,
171
+ p... ;
172
+ kwargs... ,
173
+ postprocess_fbody = pre,
174
+ states = sol_states,
175
+ wrap_code = wrap_code .∘ wrap_mtkparameters (sys, isscalar) .∘
176
+ wrap_array_vars (sys, exprs; dvs)
177
+ )
178
+ end
179
+ end
180
+
181
+ function wrap_array_vars (sys:: AbstractSystem , exprs; dvs = unknowns (sys))
182
+ isscalar = ! (exprs isa AbstractArray)
183
+ allvars = if isscalar
184
+ Set (get_variables (exprs))
185
+ else
186
+ union (get_variables .(exprs)... )
187
+ end
188
+ array_vars = Dict {Any, AbstractArray{Int}} ()
189
+ for (j, x) in enumerate (dvs)
190
+ if istree (x) && operation (x) == getindex
191
+ arg = arguments (x)[1 ]
192
+ arg in allvars || continue
193
+ inds = get! (() -> Int[], array_vars, arg)
194
+ push! (inds, j)
195
+ end
196
+ end
197
+ for (k, inds) in array_vars
198
+ if inds == (inds′ = inds[1 ]: inds[end ])
199
+ array_vars[k] = inds′
200
+ end
201
+ end
202
+ if isscalar
203
+ function (expr)
204
+ Func (
205
+ expr. args,
206
+ [],
207
+ Let (
208
+ [k ← :(view ($ (expr. args[1 ]. name), $ v)) for (k, v) in array_vars],
209
+ expr. body,
210
+ false
211
+ )
212
+ )
213
+ end
214
+ else
215
+ function (expr)
216
+ Func (
217
+ expr. args,
218
+ [],
219
+ Let (
220
+ [k ← :(view ($ (expr. args[1 ]. name), $ v)) for (k, v) in array_vars],
221
+ expr. body,
222
+ false
223
+ )
224
+ )
225
+ end ,
226
+ function (expr)
227
+ Func (
228
+ expr. args,
229
+ [],
230
+ Let (
231
+ [k ← :(view ($ (expr. args[2 ]. name), $ v)) for (k, v) in array_vars],
232
+ expr. body,
233
+ false
234
+ )
235
+ )
236
+ end
237
+ end
238
+ end
239
+
240
+ function wrap_mtkparameters (sys:: AbstractSystem , isscalar:: Bool )
241
+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
242
+ offset = Int (is_time_dependent (sys))
243
+
244
+ if isscalar
245
+ function (expr)
246
+ p = gensym (:p )
247
+ Func (
248
+ [
249
+ expr. args[1 ],
250
+ DestructuredArgs (
251
+ [arg. name for arg in expr. args[2 : (end - offset)]], p),
252
+ (isone (offset) ? (expr. args[end ],) : ()). ..
253
+ ],
254
+ [],
255
+ Let (expr. args[2 : (end - offset)], expr. body, false )
256
+ )
257
+ end
258
+ else
259
+ function (expr)
260
+ p = gensym (:p )
261
+ Func (
262
+ [
263
+ expr. args[1 ],
264
+ DestructuredArgs (
265
+ [arg. name for arg in expr. args[2 : (end - offset)]], p),
266
+ (isone (offset) ? (expr. args[end ],) : ()). ..
267
+ ],
268
+ [],
269
+ Let (expr. args[2 : (end - offset)], expr. body, false )
270
+ )
271
+ end ,
272
+ function (expr)
273
+ p = gensym (:p )
274
+ Func (
275
+ [
276
+ expr. args[1 ],
277
+ expr. args[2 ],
278
+ DestructuredArgs (
279
+ [arg. name for arg in expr. args[3 : (end - offset)]], p),
280
+ (isone (offset) ? (expr. args[end ],) : ()). ..
281
+ ],
282
+ [],
283
+ Let (expr. args[3 : (end - offset)], expr. body, false )
284
+ )
285
+ end
286
+ end
287
+ else
288
+ identity
289
+ end
290
+ end
291
+
148
292
mutable struct Substitutions
149
293
subs:: Vector{Equation}
150
294
deps:: Vector{Vector{Int}}
0 commit comments