Skip to content

Commit 7dd79c8

Browse files
committed
Add preface
1 parent 9d8de6e commit 7dd79c8

File tree

4 files changed

+44
-6
lines changed

4 files changed

+44
-6
lines changed

src/systems/abstractsystem.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ for prop in [
165165
:depvars
166166
:indvars
167167
:connection_type
168+
:preface
168169
]
169170
fname1 = Symbol(:get_, prop)
170171
fname2 = Symbol(:has_, prop)
@@ -357,6 +358,12 @@ function namespace_equation(eq::Equation, sys)
357358
_lhs ~ _rhs
358359
end
359360

361+
function namespace_assignment(eq::Assignment, sys)
362+
_lhs = namespace_expr(eq.lhs, sys)
363+
_rhs = namespace_expr(eq.rhs, sys)
364+
Assignment(_lhs, _rhs)
365+
end
366+
360367
function namespace_expr(O, sys) where {T}
361368
iv = independent_variable(sys)
362369
O = unwrap(O)
@@ -435,6 +442,25 @@ function equations(sys::ModelingToolkit.AbstractSystem)
435442
end
436443
end
437444

445+
function preface(sys::ModelingToolkit.AbstractSystem)
446+
has_preface(sys) || return nothing
447+
pre = get_preface(sys)
448+
systems = get_systems(sys)
449+
if isempty(systems)
450+
return pre
451+
else
452+
pres = pre === nothing ? [] : pre
453+
for sys in systems
454+
pre = get_preface(sys)
455+
pre === nothing && continue
456+
for eq in pre
457+
push!(pres, namespace_assignment(eq, sys))
458+
end
459+
end
460+
return isempty(pres) ? nothing : pres
461+
end
462+
end
463+
438464
function islinear(sys::AbstractSystem)
439465
rhs = [eq.rhs for eq equations(sys)]
440466

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,16 @@ function generate_function(
102102
p = map(x->time_varying_as_func(value(x), sys), ps)
103103
t = get_iv(sys)
104104

105+
if has_preface(sys)
106+
pre = ex -> Let(preface(sys), ex)
107+
else
108+
pre = ex -> ex
109+
end
110+
105111
if implicit_dae
106-
build_function(rhss, ddvs, u, p, t; kwargs...)
112+
build_function(rhss, ddvs, u, p, t; postprocess_fbody=pre, kwargs...)
107113
else
108-
build_function(rhss, u, p, t; kwargs...)
114+
build_function(rhss, u, p, t; postprocess_fbody=pre, kwargs...)
109115
end
110116
end
111117

src/systems/diffeqs/odesystem.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,19 @@ struct ODESystem <: AbstractODESystem
8080
"""
8181
structure::Any
8282
"""
83-
type: type of the system
83+
connection_type: type of the system
8484
"""
8585
connection_type::Any
86+
"""
87+
preface: injuect assignment statements before the evaluation of the RHS function.
88+
"""
89+
preface::Any
8690

87-
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
91+
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type, preface)
8892
check_variables(dvs,iv)
8993
check_parameters(ps,iv)
9094
check_equations(deqs,iv)
91-
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
95+
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type, preface)
9296
end
9397
end
9498

@@ -102,6 +106,7 @@ function ODESystem(
102106
default_p=Dict(),
103107
defaults=_merge(Dict(default_u0), Dict(default_p)),
104108
connection_type=nothing,
109+
preface=nothing,
105110
)
106111
name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
107112
deqs = collect(deqs)
@@ -135,7 +140,7 @@ function ODESystem(
135140
if length(unique(sysnames)) != length(sysnames)
136141
throw(ArgumentError("System names must be unique."))
137142
end
138-
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
143+
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type, preface)
139144
end
140145

141146
function ODESystem(eqs, iv=nothing; kwargs...)

src/variables.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true, toterm=Sym
5050
elseif container_type <: Tuple
5151
(vals...,)
5252
else
53+
vals = identity.(vals)
5354
SymbolicUtils.Code.create_array(container_type, eltype(vals), Val{1}(), Val(length(vals)), vals...)
5455
end
5556
end

0 commit comments

Comments
 (0)