Skip to content

Commit 19eafce

Browse files
feat: add generate_custom_function
1 parent 13badcc commit 19eafce

File tree

5 files changed

+211
-23
lines changed

5 files changed

+211
-23
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ export independent_variable, equations, controls,
233233
observed, full_equations
234234
export structural_simplify, expand_connections, linearize, linearization_function
235235

236-
export calculate_jacobian, generate_jacobian, generate_function
236+
export calculate_jacobian, generate_jacobian, generate_function, generate_custom_function
237237
export calculate_control_jacobian, generate_control_jacobian
238238
export calculate_tgrad, generate_tgrad
239239
export calculate_gradient, generate_gradient

src/systems/abstractsystem.jl

Lines changed: 150 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ function calculate_hessian end
8282

8383
"""
8484
```julia
85-
generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = parameters(sys),
85+
generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = full_parameters(sys),
8686
expression = Val{true}; kwargs...)
8787
```
8888
@@ -93,7 +93,7 @@ function generate_tgrad end
9393

9494
"""
9595
```julia
96-
generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
96+
generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
9797
expression = Val{true}; kwargs...)
9898
```
9999
@@ -104,7 +104,7 @@ function generate_gradient end
104104

105105
"""
106106
```julia
107-
generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
107+
generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
108108
expression = Val{true}; sparse = false, kwargs...)
109109
```
110110
@@ -115,7 +115,7 @@ function generate_jacobian end
115115

116116
"""
117117
```julia
118-
generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
118+
generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
119119
expression = Val{true}; sparse = false, kwargs...)
120120
```
121121
@@ -126,7 +126,7 @@ function generate_factorized_W end
126126

127127
"""
128128
```julia
129-
generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
129+
generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
130130
expression = Val{true}; sparse = false, kwargs...)
131131
```
132132
@@ -137,14 +137,158 @@ function generate_hessian end
137137

138138
"""
139139
```julia
140-
generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
140+
generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
141141
expression = Val{true}; kwargs...)
142142
```
143143
144144
Generate a function to evaluate the system's equations.
145145
"""
146146
function generate_function end
147147

148+
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
149+
ps = parameters(sys); wrap_code = nothing, kwargs...)
150+
p = reorder_parameters(sys, ps)
151+
isscalar = !(exprs isa AbstractArray)
152+
if wrap_code === nothing
153+
wrap_code = isscalar ? identity : (identity, identity)
154+
end
155+
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
156+
157+
if is_time_dependent(sys)
158+
return build_function(exprs,
159+
dvs,
160+
p...,
161+
get_iv(sys);
162+
kwargs...,
163+
postprocess_fbody = pre,
164+
states = sol_states,
165+
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
166+
wrap_array_vars(sys, exprs; dvs)
167+
)
168+
else
169+
return build_function(exprs,
170+
dvs,
171+
p...;
172+
kwargs...,
173+
postprocess_fbody = pre,
174+
states = sol_states,
175+
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
176+
wrap_array_vars(sys, exprs; dvs)
177+
)
178+
end
179+
end
180+
181+
function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
182+
isscalar = !(exprs isa AbstractArray)
183+
allvars = if isscalar
184+
Set(get_variables(exprs))
185+
else
186+
union(get_variables.(exprs)...)
187+
end
188+
array_vars = Dict{Any, AbstractArray{Int}}()
189+
for (j, x) in enumerate(dvs)
190+
if istree(x) && operation(x) == getindex
191+
arg = arguments(x)[1]
192+
arg in allvars || continue
193+
inds = get!(() -> Int[], array_vars, arg)
194+
push!(inds, j)
195+
end
196+
end
197+
for (k, inds) in array_vars
198+
if inds == (inds′ = inds[1]:inds[end])
199+
array_vars[k] = inds′
200+
end
201+
end
202+
if isscalar
203+
function (expr)
204+
Func(
205+
expr.args,
206+
[],
207+
Let(
208+
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
209+
expr.body,
210+
false
211+
)
212+
)
213+
end
214+
else
215+
function (expr)
216+
Func(
217+
expr.args,
218+
[],
219+
Let(
220+
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
221+
expr.body,
222+
false
223+
)
224+
)
225+
end,
226+
function (expr)
227+
Func(
228+
expr.args,
229+
[],
230+
Let(
231+
[k :(view($(expr.args[2].name), $v)) for (k, v) in array_vars],
232+
expr.body,
233+
false
234+
)
235+
)
236+
end
237+
end
238+
end
239+
240+
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool)
241+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
242+
offset = Int(is_time_dependent(sys))
243+
244+
if isscalar
245+
function (expr)
246+
p = gensym(:p)
247+
Func(
248+
[
249+
expr.args[1],
250+
DestructuredArgs(
251+
[arg.name for arg in expr.args[2:(end - offset)]], p),
252+
(isone(offset) ? (expr.args[end],) : ())...
253+
],
254+
[],
255+
Let(expr.args[2:(end - offset)], expr.body, false)
256+
)
257+
end
258+
else
259+
function (expr)
260+
p = gensym(:p)
261+
Func(
262+
[
263+
expr.args[1],
264+
DestructuredArgs(
265+
[arg.name for arg in expr.args[2:(end - offset)]], p),
266+
(isone(offset) ? (expr.args[end],) : ())...
267+
],
268+
[],
269+
Let(expr.args[2:(end - offset)], expr.body, false)
270+
)
271+
end,
272+
function (expr)
273+
p = gensym(:p)
274+
Func(
275+
[
276+
expr.args[1],
277+
expr.args[2],
278+
DestructuredArgs(
279+
[arg.name for arg in expr.args[3:(end - offset)]], p),
280+
(isone(offset) ? (expr.args[end],) : ())...
281+
],
282+
[],
283+
Let(expr.args[3:(end - offset)], expr.body, false)
284+
)
285+
end
286+
end
287+
else
288+
identity
289+
end
290+
end
291+
148292
mutable struct Substitutions
149293
subs::Vector{Equation}
150294
deps::Vector{Vector{Int}}

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -152,22 +152,8 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
152152
ddvs = implicit_dae ? map(Differential(get_iv(sys)), dvs) :
153153
nothing,
154154
isdde = false,
155+
wrap_code = nothing,
155156
kwargs...)
156-
array_vars = Dict{Any, Vector{Int}}()
157-
for (j, x) in enumerate(dvs)
158-
if istree(x) && operation(x) == getindex
159-
arg = arguments(x)[1]
160-
inds = get!(() -> Int[], array_vars, arg)
161-
push!(inds, j)
162-
end
163-
end
164-
subs = Dict()
165-
for (k, inds) in array_vars
166-
if inds == (inds′ = inds[1]:inds[end])
167-
inds = inds′
168-
end
169-
subs[k] = term(view, Sym{Any}(Symbol("ˍ₋arg1")), inds)
170-
end
171157
if isdde
172158
eqs = delay_to_function(sys)
173159
else
@@ -180,13 +166,15 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
180166
# substitute x(t) by just x
181167
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
182168
[eq.rhs for eq in eqs]
183-
rhss = fast_substitute(rhss, subs)
184169

185170
# TODO: add an optional check on the ordering of observed equations
186171
u = map(x -> time_varying_as_func(value(x), sys), dvs)
187172
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
188173
t = get_iv(sys)
189174

175+
if wrap_code === nothing
176+
wrap_code = (identity, identity)
177+
end
190178
if isdde
191179
build_function(rhss, u, DDE_HISTORY_FUN, p..., t; kwargs...)
192180
else
@@ -195,10 +183,12 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
195183
if implicit_dae
196184
build_function(rhss, ddvs, u, p..., t; postprocess_fbody = pre,
197185
states = sol_states,
186+
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
198187
kwargs...)
199188
else
200189
build_function(rhss, u, p..., t; postprocess_fbody = pre,
201190
states = sol_states,
191+
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
202192
kwargs...)
203193
end
204194
end

test/generate_custom_function.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using ModelingToolkit
2+
using ModelingToolkit: t_nounits as t, D_nounits as D
3+
using IfElse
4+
5+
@variables x(t) y(t)[1:3]
6+
@parameters p1=1.0 p2[1:3]=[1.0, 2.0, 3.0] p3::Int=1 p4::Bool=false
7+
8+
sys = complete(ODESystem(Equation[], t, [x; y], [p1, p2, p3, p4]; name = :sys))
9+
u0 = [1.0, 2.0, 3.0, 4.0]
10+
p = ModelingToolkit.MTKParameters(sys, [])
11+
12+
fn1 = generate_custom_function(sys, x + y[1] + p1 + p2[1] + p3 * t; expression = Val(false))
13+
@test fn1(u0, p, 0.0) == 5.0
14+
15+
fn2 = generate_custom_function(
16+
sys, x + y[1] + p1 + p2[1] + p3 * t, [x], [p1, p2, p3]; expression = Val(false))
17+
@test fn1(u0, p, 0.0) == 5.0
18+
19+
fn3_oop, fn3_iip = generate_custom_function(
20+
sys, [x + y[2], y[3] + p2[2], p1 + p3, 3t]; expression = Val(false))
21+
22+
buffer = zeros(4)
23+
fn3_iip(buffer, u0, p, 1.0)
24+
@test buffer == [4.0, 6.0, 2.0, 3.0]
25+
@test fn3_oop(u0, p, 1.0) == [4.0, 6.0, 2.0, 3.0]
26+
27+
fn4 = generate_custom_function(sys, IfElse.ifelse(p4, p1, p2[2]); expression = Val(false))
28+
@test fn4(u0, p, 1.0) == 2.0
29+
fn5 = generate_custom_function(sys, IfElse.ifelse(!p4, p1, p2[2]); expression = Val(false))
30+
@test fn5(u0, p, 1.0) == 1.0
31+
32+
@variables x y[1:3]
33+
sys = complete(NonlinearSystem(Equation[], [x; y], [p1, p2, p3, p4]; name = :sys))
34+
35+
fn1 = generate_custom_function(sys, x + y[1] + p1 + p2[1] + p3; expression = Val(false))
36+
@test fn1(u0, p) == 6.0
37+
38+
fn2 = generate_custom_function(
39+
sys, x + y[1] + p1 + p2[1] + p3, [x], [p1, p2, p3]; expression = Val(false))
40+
@test fn1(u0, p) == 6.0
41+
42+
fn3_oop, fn3_iip = generate_custom_function(
43+
sys, [x + y[2], y[3] + p2[2], p1 + p3]; expression = Val(false))
44+
45+
buffer = zeros(3)
46+
fn3_iip(buffer, u0, p)
47+
@test buffer == [4.0, 6.0, 2.0]
48+
@test fn3_oop(u0, p, 1.0) == [4.0, 6.0, 2.0]
49+
50+
fn4 = generate_custom_function(sys, IfElse.ifelse(p4, p1, p2[2]); expression = Val(false))
51+
@test fn4(u0, p, 1.0) == 2.0
52+
fn5 = generate_custom_function(sys, IfElse.ifelse(!p4, p1, p2[2]); expression = Val(false))
53+
@test fn5(u0, p, 1.0) == 1.0

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ end
6363
@safetestset "FuncAffect Test" include("funcaffect.jl")
6464
@safetestset "Constants Test" include("constants.jl")
6565
@safetestset "Parameter Dependency Test" include("parameter_dependencies.jl")
66+
@safetestset "Generate Custom Function Test" include("generate_custom_function.jl")
6667
end
6768

6869
if GROUP == "All" || GROUP == "InterfaceII"

0 commit comments

Comments
 (0)