Skip to content

Commit eb33a8a

Browse files
authored
Merge pull request #2355 from SciML/fb/discrete_timing
add component-based hybrid system test
2 parents 7685996 + b91fafe commit eb33a8a

File tree

5 files changed

+262
-51
lines changed

5 files changed

+262
-51
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ julia = "1.9"
113113
[extras]
114114
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
115115
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
116+
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
116117
ControlSystemsMTK = "687d7614-c7e5-45fc-bfc3-9ee385575c88"
117118
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
118119
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -136,4 +137,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
136137
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
137138

138139
[targets]
139-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg"]
140+
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg"]

src/systems/clock_inference.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
150150
param_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
151151
offset = length(appended_parameters)
152152
affect_funs = []
153+
init_funs = []
153154
svs = []
154155
clocks = TimeDomain[]
155156
for (i, (sys, input)) in enumerate(zip(syss, inputs))
@@ -202,6 +203,18 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
202203
push!(save_vec.args, :(p[$(input_offset + i)]))
203204
end
204205
empty_disc = isempty(disc_range)
206+
207+
disc_init = :(function (p, t)
208+
d2c_obs = $disc_to_cont_obs
209+
d2c_view = view(p, $disc_to_cont_idxs)
210+
disc_state = view(p, $disc_range)
211+
copyto!(d2c_view, d2c_obs(disc_state, p, t))
212+
end)
213+
214+
# @show disc_to_cont_idxs
215+
# @show cont_to_disc_idxs
216+
# @show disc_range
217+
205218
affect! = :(function (integrator, saved_values)
206219
@unpack u, p, t = integrator
207220
c2d_obs = $cont_to_disc_obs
@@ -212,27 +225,42 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
212225
d2c_view = view(p, $disc_to_cont_idxs)
213226
disc_state = view(p, $disc_range)
214227
disc = $disc
215-
# Write continuous into to discrete: handles `Sample`
216-
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
217-
# Write discrete into to continuous
218-
# get old discrete states
219-
copyto!(d2c_view, d2c_obs(disc_state, p, t))
228+
220229
push!(saved_values.t, t)
221230
push!(saved_values.saveval, $save_vec)
222-
# update discrete states
231+
232+
# Write continuous into to discrete: handles `Sample`
233+
# Write discrete into to continuous
234+
# Update discrete states
235+
236+
# At a tick, c2d must come first
237+
# state update comes in the middle
238+
# d2c comes last
239+
# @show t
240+
# @show "incoming", p
241+
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
242+
# @show "after c2d", p
223243
$empty_disc || disc(disc_state, disc_state, p, t)
244+
# @show "after state update", p
245+
copyto!(d2c_view, d2c_obs(disc_state, p, t))
246+
# @show "after d2c", p
224247
end)
225248
sv = SavedValues(Float64, Vector{Float64})
226249
push!(affect_funs, affect!)
250+
push!(init_funs, disc_init)
227251
push!(svs, sv)
228252
end
229253
if eval_expression
230254
affects = map(affect_funs) do a
231255
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
232256
end
257+
inits = map(init_funs) do a
258+
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
259+
end
233260
else
234261
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
262+
inits = map(a -> toexpr(LiteralExpr(a)), init_funs)
235263
end
236264
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
237-
return affects, clocks, svs, appended_parameters, defaults
265+
return affects, inits, clocks, svs, appended_parameters, defaults
238266
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -945,8 +945,9 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
945945
has_difference = has_difference,
946946
check_length, kwargs...)
947947
cbs = process_events(sys; callback, has_difference, kwargs...)
948+
inits = []
948949
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
949-
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
950+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
950951
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
951952
if clock isa Clock
952953
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -976,7 +977,13 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
976977
if svs !== nothing
977978
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
978979
end
979-
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
980+
prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
981+
if !isempty(inits)
982+
for init in inits
983+
init(prob.p, tspan[1])
984+
end
985+
end
986+
prob
980987
end
981988
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
982989

@@ -1045,8 +1052,9 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10451052
h = h_oop
10461053
u0 = h(p, tspan[1])
10471054
cbs = process_events(sys; callback, has_difference, kwargs...)
1055+
inits = []
10481056
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1049-
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
1057+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
10501058
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
10511059
if clock isa Clock
10521060
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -1075,7 +1083,13 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10751083
if svs !== nothing
10761084
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
10771085
end
1078-
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
1086+
prob = DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
1087+
if !isempty(inits)
1088+
for init in inits
1089+
init(prob.p, tspan[1])
1090+
end
1091+
end
1092+
prob
10791093
end
10801094

10811095
function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
@@ -1099,8 +1113,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10991113
h(p, t) = h_oop(p, t)
11001114
u0 = h(p, tspan[1])
11011115
cbs = process_events(sys; callback, has_difference, kwargs...)
1116+
inits = []
11021117
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1103-
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
1118+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
11041119
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
11051120
if clock isa Clock
11061121
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -1140,8 +1155,15 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
11401155
else
11411156
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
11421157
end
1143-
SDDEProblem{iip}(f, f.g, u0, h, tspan, p; noise_rate_prototype =
1158+
prob = SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
1159+
noise_rate_prototype =
11441160
noise_rate_prototype, kwargs1..., kwargs...)
1161+
if !isempty(inits)
1162+
for init in inits
1163+
init(prob.p, tspan[1])
1164+
end
1165+
end
1166+
prob
11451167
end
11461168

11471169
"""

0 commit comments

Comments
 (0)