Skip to content

Commit ff72509

Browse files
Merge pull request #3210 from AayushSabharwal/as/generate-control-function
fix: enable using `MTKParameters` with `generate_control_function`
2 parents 519038e + b0e4c7a commit ff72509

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

src/inputoutput.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,18 +235,20 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
235235
# TODO: add an optional check on the ordering of observed equations
236236
u = map(x -> time_varying_as_func(value(x), sys), dvs)
237237
p = map(x -> time_varying_as_func(value(x), sys), ps)
238+
p = reorder_parameters(sys, p)
238239
t = get_iv(sys)
239240

240241
# pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)
241242

242-
args = (u, inputs, p, t)
243+
args = (u, inputs, p..., t)
243244
if implicit_dae
244245
ddvs = map(Differential(get_iv(sys)), dvs)
245246
args = (ddvs, args...)
246247
end
247248
process = get_postprocess_fbody(sys)
248249
f = build_function(rhss, args...; postprocess_fbody = process,
249-
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps) .∘
250+
expression = Val{true}, wrap_code = wrap_mtkparameters(sys, false, 3) .∘
251+
wrap_array_vars(sys, rhss; dvs, ps, inputs) .∘
250252
wrap_parameter_dependencies(sys, false),
251253
kwargs...)
252254
f = eval_or_rgf.(f; eval_expression, eval_module)
@@ -426,7 +428,7 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing; kw
426428
augmented_sys = ODESystem(eqs, t, systems = [dsys], name = gensym(:outer))
427429
augmented_sys = extend(augmented_sys, sys)
428430

429-
(f_oop, f_ip), dvs, p = generate_control_function(augmented_sys, all_inputs,
431+
(f_oop, f_ip), dvs, p, io_sys = generate_control_function(augmented_sys, all_inputs,
430432
[d]; kwargs...)
431-
(f_oop, f_ip), augmented_sys, dvs, p
433+
(f_oop, f_ip), augmented_sys, dvs, p, io_sys
432434
end

test/input_output_handling.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,12 @@ eqs = [
160160
]
161161

162162
@named sys = ODESystem(eqs, t)
163-
f, dvs, ps = ModelingToolkit.generate_control_function(sys, simplify = true)
163+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true)
164164

165165
@test isequal(dvs[], x)
166166
@test isempty(ps)
167167

168-
p = []
168+
p = nothing
169169
x = [rand()]
170170
u = [rand()]
171171
@test f[1](x, u, p, 1) == -x + u
@@ -221,10 +221,10 @@ eqs = [connect_sd(sd, mass1, mass2)
221221
@named _model = ODESystem(eqs, t)
222222
@named model = compose(_model, mass1, mass2, sd);
223223

224-
f, dvs, ps = ModelingToolkit.generate_control_function(model, simplify = true)
224+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(model, simplify = true)
225225
@test length(dvs) == 4
226226
@test length(ps) == length(parameters(model))
227-
p = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), ps)
227+
p = MTKParameters(io_sys, [io_sys.u => NaN])
228228
x = ModelingToolkit.varmap_to_vars(
229229
merge(ModelingToolkit.defaults(model),
230230
Dict(D.(unknowns(model)) .=> 0.0)), dvs)
@@ -288,7 +288,7 @@ model_outputs = [model.inertia1.w, model.inertia2.w, model.inertia1.phi, model.i
288288
@named dmodel = Blocks.StateSpace([0.0], [1.0], [1.0], [0.0]) # An integrating disturbance
289289

290290
@named dist = ModelingToolkit.DisturbanceModel(model.torque.tau.u, dmodel)
291-
(f_oop, f_ip), outersys, dvs, p = ModelingToolkit.add_input_disturbance(model, dist)
291+
(f_oop, f_ip), outersys, dvs, p, io_sys = ModelingToolkit.add_input_disturbance(model, dist)
292292

293293
@unpack u, d = outersys
294294
matrices, ssys = linearize(outersys, [u, d], model_outputs)
@@ -302,7 +302,7 @@ x_add = ModelingToolkit.varmap_to_vars(merge(Dict(dvs .=> 0), Dict(dstate => 1))
302302
x0 = randn(5)
303303
x1 = copy(x0) + x_add # add disturbance state perturbation
304304
u = randn(1)
305-
pn = ModelingToolkit.varmap_to_vars(def, p)
305+
pn = MTKParameters(io_sys, [])
306306
xp0 = f_oop(x0, u, pn, 0)
307307
xp1 = f_oop(x1, u, pn, 0)
308308

@@ -401,3 +401,15 @@ end
401401
f, dvs, ps = ModelingToolkit.generate_control_function(sys, simplify = true)
402402
@test f[1]([0.5], nothing, nothing, 0.0) == [1.0]
403403
end
404+
405+
@testset "With callable symbolic" begin
406+
@variables x(t)=0 u(t)=0 [input = true]
407+
@parameters p(::Real) = (x -> 2x)
408+
eqs = [D(x) ~ -x + p(u)]
409+
@named sys = ODESystem(eqs, t)
410+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true)
411+
p = MTKParameters(io_sys, [])
412+
u = [1.0]
413+
x = [1.0]
414+
@test_nowarn f[1](x, u, p, 0.0)
415+
end

0 commit comments

Comments
 (0)