Skip to content

Commit 5a47dd6

Browse files
authored
Merge pull request #1157 from SciML/sm/refactor
More robust array symbolic handling
2 parents 89fdee8 + c5e4974 commit 5a47dd6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+316
-276
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
4-
version = "5.26.0"
4+
version = "6.0.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -73,7 +73,7 @@ Setfield = "0.7"
7373
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0"
7474
StaticArrays = "0.10, 0.11, 0.12, 1.0"
7575
SymbolicUtils = "0.12, 0.13"
76-
Symbolics = "1.4.1"
76+
Symbolics = "2.0"
7777
UnPack = "0.1, 1.0"
7878
Unitful = "1.1"
7979
julia = "1.2"

examples/rc_model.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ rc_eqs = [
1414
connect(capacitor.n, source.n, ground.g)
1515
]
1616

17-
@named rc_model = compose(ODESystem(rc_eqs, t), [resistor, capacitor, source, ground])
17+
@named rc_model = ODESystem(rc_eqs, t)
18+
rc_model = compose(rc_model, [resistor, capacitor, source, ground])

examples/serial_inductor.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ eqs = [
1313
connect(source.n, inductor2.n, ground.g)
1414
]
1515

16-
@named ll_model = compose(ODESystem(eqs, t), source, resistor, inductor1, inductor2, ground)
16+
@named ll_model = ODESystem(eqs, t)
17+
ll_model = compose(ll_model, [source, resistor, inductor1, inductor2, ground])

src/ModelingToolkit.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ using RecursiveArrayTools
3333
import SymbolicUtils
3434
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,
3535
Symbolic, Term, Add, Mul, Pow, Sym, FnType,
36-
@rule, Rewriters, substitute
36+
@rule, Rewriters, substitute, metadata
3737
using SymbolicUtils.Code
3838
import SymbolicUtils.Code: toexpr
3939
import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint
@@ -43,7 +43,8 @@ using Reexport
4343
@reexport using Symbolics
4444
export @derivatives
4545
using Symbolics: _parse_vars, value, @derivatives, get_variables,
46-
exprs_occur_in, solve_for, build_expr, unwrap, wrap
46+
exprs_occur_in, solve_for, build_expr, unwrap, wrap,
47+
VariableSource, getname, variable
4748
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
4849
jacobian_sparsity, islinear, _iszero, _isone,
4950
tosymbol, lower_varname, diff2term, var_from_nested_derivative,

src/systems/abstractsystem.jl

Lines changed: 77 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,6 @@ function generate_function end
131131

132132
Base.nameof(sys::AbstractSystem) = getfield(sys, :name)
133133

134-
function getname(t)
135-
if istree(t)
136-
operation(t) isa Sym ? getname(operation(t)) : error("Cannot get name of $t")
137-
else
138-
nameof(t)
139-
end
140-
end
141134
#Deprecated
142135
function independent_variable(sys::AbstractSystem)
143136
Base.depwarn("`independent_variable` is deprecated. Use `get_iv` or `independent_variables` instead.",:independent_variable)
@@ -193,6 +186,7 @@ for prop in [
193186
:depvars
194187
:indvars
195188
:connection_type
189+
:preface
196190
]
197191
fname1 = Symbol(:get_, prop)
198192
fname2 = Symbol(:has_, prop)
@@ -248,37 +242,35 @@ function Base.propertynames(sys::AbstractSystem; private=false)
248242
end
249243
end
250244

251-
Base.getproperty(sys::AbstractSystem, name::Symbol; namespace=true) = getvar(sys, name; namespace=namespace)
245+
Base.getproperty(sys::AbstractSystem, name::Symbol; namespace=true) = wrap(getvar(sys, name; namespace=namespace))
252246
function getvar(sys::AbstractSystem, name::Symbol; namespace=false)
253-
sysname = nameof(sys)
254247
systems = get_systems(sys)
255248
if isdefined(sys, name)
256249
Base.depwarn("`sys.name` like `sys.$name` is deprecated. Use getters like `get_$name` instead.", "sys.$name")
257250
return getfield(sys, name)
258251
elseif !isempty(systems)
259252
i = findfirst(x->nameof(x)==name, systems)
260253
if i !== nothing
261-
return namespace ? rename(systems[i], renamespace(sysname, name)) : systems[i]
254+
return namespace ? rename(systems[i], renamespace(sys, name)) : systems[i]
262255
end
263256
end
264257

265258
if has_var_to_name(sys)
266259
avs = get_var_to_name(sys)
267260
v = get(avs, name, nothing)
268-
v === nothing || return namespace ? renamespace(sysname, v, name) : v
269-
261+
v === nothing || return namespace ? renamespace(sys, v) : v
270262
else
271263
sts = get_states(sys)
272264
i = findfirst(x->getname(x) == name, sts)
273265
if i !== nothing
274-
return namespace ? renamespace(sysname,sts[i]) : sts[i]
266+
return namespace ? renamespace(sys, sts[i]) : sts[i]
275267
end
276268

277269
if has_ps(sys)
278270
ps = get_ps(sys)
279271
i = findfirst(x->getname(x) == name,ps)
280272
if i !== nothing
281-
return namespace ? renamespace(sysname,ps[i]) : ps[i]
273+
return namespace ? renamespace(sys, ps[i]) : ps[i]
282274
end
283275
end
284276
end
@@ -290,7 +282,7 @@ function getvar(sys::AbstractSystem, name::Symbol; namespace=false)
290282
obs = get_observed(sys)
291283
i = findfirst(x->getname(x.lhs)==name,obs)
292284
if i !== nothing
293-
return namespace ? renamespace(sysname,obs[i]) : obs[i]
285+
return namespace ? renamespace(sys, obs[i]) : obs[i]
294286
end
295287
end
296288

@@ -330,20 +322,39 @@ ParentScope(sym::Union{Num, Symbolic}) = setmetadata(sym, SymScope, ParentScope(
330322
struct GlobalScope <: SymScope end
331323
GlobalScope(sym::Union{Num, Symbolic}) = setmetadata(sym, SymScope, GlobalScope())
332324

333-
function renamespace(namespace, x, name=nothing)
325+
renamespace(sys, eq::Equation) = namespace_equation(eq, sys)
326+
327+
function _renamespace(sys, x)
328+
v = unwrap(x)
329+
330+
if istree(v) && symtype(operation(v)) <: FnType
331+
ov = metadata(operation(v), metadata(v))
332+
return similarterm(v, renamespace(sys, ov), arguments(v), symtype(v), metadata=metadata(v))
333+
end
334+
335+
if v isa Namespace
336+
sysp, v = v.parent, v.named
337+
sysn = Symbol(getname(sys), :., getname(sysp))
338+
sys = sys isa AbstractSystem ? rename(sysp, sysn) : sysn
339+
end
340+
341+
Namespace(sys, v)
342+
end
343+
344+
function renamespace(sys, x)
334345
x = unwrap(x)
335346
if x isa Symbolic
336347
let scope = getmetadata(x, SymScope, LocalScope())
337348
if scope isa LocalScope
338-
rename(x, renamespace(namespace, name === nothing ? getname(x) : name))
349+
_renamespace(sys, x)
339350
elseif scope isa ParentScope
340351
setmetadata(x, SymScope, scope.parent)
341352
else # GlobalScope
342353
x
343354
end
344355
end
345356
else
346-
Symbol(namespace,:₊,x)
357+
Symbol(getname(sys), :., x)
347358
end
348359
end
349360

@@ -353,33 +364,38 @@ namespace_controls(sys::AbstractSystem) = controls(sys, controls(sys))
353364

354365
function namespace_defaults(sys)
355366
defs = defaults(sys)
356-
Dict((isparameter(k) ? parameters(sys, k) : states(sys, k)) => namespace_expr(defs[k], nameof(sys), independent_variables(sys)) for k in keys(defs))
367+
Dict((isparameter(k) ? parameters(sys, k) : states(sys, k)) => namespace_expr(defs[k], sys) for k in keys(defs))
357368
end
358369

359370
function namespace_equations(sys::AbstractSystem)
360371
eqs = equations(sys)
361372
isempty(eqs) && return Equation[]
362-
ivs = independent_variables(sys)
363-
map(eq -> namespace_equation(eq, nameof(sys), ivs), eqs)
373+
map(eq->namespace_equation(eq, sys), eqs)
364374
end
365375

366-
function namespace_equation(eq::Equation, name, ivs)
367-
_lhs = namespace_expr(eq.lhs, name, ivs)
368-
_rhs = namespace_expr(eq.rhs, name, ivs)
376+
function namespace_equation(eq::Equation, sys)
377+
_lhs = namespace_expr(eq.lhs, sys)
378+
_rhs = namespace_expr(eq.rhs, sys)
369379
_lhs ~ _rhs
370380
end
371381

372-
function namespace_expr(O::Sym, name, ivs)
373-
any(isequal(O), ivs) ? O : renamespace(name, O)
382+
function namespace_assignment(eq::Assignment, sys)
383+
_lhs = namespace_expr(eq.lhs, sys)
384+
_rhs = namespace_expr(eq.rhs, sys)
385+
Assignment(_lhs, _rhs)
374386
end
375387

376-
_symparam(s::Symbolic{T}) where {T} = T
377-
function namespace_expr(O, name, ivs) where {T}
378-
O = value(O)
379-
if istree(O)
380-
renamed = map(a -> namespace_expr(a, name, ivs), arguments(O))
381-
if operation(O) isa Sym
382-
renamespace(name, O)
388+
function namespace_expr(O, sys) where {T}
389+
ivs = independent_variables(sys)
390+
O = unwrap(O)
391+
if any(isequal(O), ivs)
392+
return O
393+
elseif isvariable(O)
394+
renamespace(sys, O)
395+
elseif istree(O)
396+
renamed = map(a->namespace_expr(a, sys), arguments(O))
397+
if symtype(operation(O)) <: FnType
398+
renamespace(sys, O)
383399
else
384400
similarterm(O, operation(O), renamed)
385401
end
@@ -409,13 +425,12 @@ function controls(sys::AbstractSystem)
409425
end
410426

411427
function observed(sys::AbstractSystem)
412-
ivs = independent_variables(sys)
413428
obs = get_observed(sys)
414429
systems = get_systems(sys)
415430
[obs;
416431
reduce(vcat,
417-
(map(o -> namespace_equation.(o, nameof(s), ivs), observed(s)) for s in systems),
418-
init = Equation[])]
432+
(map(o->namespace_equation(o, s), observed(s)) for s in systems),
433+
init=Equation[])]
419434
end
420435

421436
Base.@deprecate default_u0(x) defaults(x) false
@@ -426,7 +441,7 @@ function defaults(sys::AbstractSystem)
426441
isempty(systems) ? defs : mapreduce(namespace_defaults, merge, systems; init=defs)
427442
end
428443

429-
states(sys::AbstractSystem, v) = renamespace(nameof(sys), v)
444+
states(sys::AbstractSystem, v) = renamespace(sys, v)
430445
parameters(sys::AbstractSystem, v) = toparam(states(sys, v))
431446
for f in [:states, :parameters]
432447
@eval $f(sys::AbstractSystem, vs::AbstractArray) = map(v->$f(sys, v), vs)
@@ -448,6 +463,25 @@ function equations(sys::ModelingToolkit.AbstractSystem)
448463
end
449464
end
450465

466+
function preface(sys::ModelingToolkit.AbstractSystem)
467+
has_preface(sys) || return nothing
468+
pre = get_preface(sys)
469+
systems = get_systems(sys)
470+
if isempty(systems)
471+
return pre
472+
else
473+
pres = pre === nothing ? [] : pre
474+
for sys in systems
475+
pre = get_preface(sys)
476+
pre === nothing && continue
477+
for eq in pre
478+
push!(pres, namespace_assignment(eq, sys))
479+
end
480+
end
481+
return isempty(pres) ? nothing : pres
482+
end
483+
end
484+
451485
function islinear(sys::AbstractSystem)
452486
rhs = [eq.rhs for eq equations(sys)]
453487

@@ -913,7 +947,11 @@ function Base.hash(sys::AbstractSystem, s::UInt)
913947
s = foldr(hash, get_systems(sys), init=s)
914948
s = foldr(hash, get_states(sys), init=s)
915949
s = foldr(hash, get_ps(sys), init=s)
916-
s = foldr(hash, get_eqs(sys), init=s)
950+
if sys isa OptimizationSystem
951+
s = hash(get_op(sys), s)
952+
else
953+
s = foldr(hash, get_eqs(sys), init=s)
954+
end
917955
s = foldr(hash, get_observed(sys), init=s)
918956
s = hash(independent_variables(sys), s)
919957
return s
@@ -935,7 +973,7 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol=nameo
935973
else
936974
throw("Extending multivariate systems is not supported")
937975
end
938-
976+
939977
eqs = union(equations(basesys), equations(sys))
940978
sts = union(states(basesys), states(sys))
941979
ps = union(parameters(basesys), parameters(sys))

src/systems/control/controlsystem.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ function ControlSystem(loss, deqs::AbstractVector{<:Equation}, iv, dvs, controls
8787
default_u0=Dict(),
8888
default_p=Dict(),
8989
defaults=_merge(Dict(default_u0), Dict(default_p)),
90-
name=gensym(:ControlSystem))
90+
name=nothing)
91+
name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
9192
if !(isempty(default_u0) && isempty(default_p))
9293
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :ControlSystem, force=true)
9394
end
@@ -165,7 +166,7 @@ function runge_kutta_discretize(sys::ControlSystem,dt,tspan;
165166
L = @RuntimeGeneratedFunction(build_function(lo,sts,ctr,ps,iv,conv = ModelingToolkit.ControlToExpr(sys)))
166167

167168
var(n, i...) = var(nameof(n), i...)
168-
var(n::Symbol, i...) = Sym{FnType{Tuple{symtype(iv)}, Real}}(nameof(Variable(n, i...)))
169+
var(n::Symbol, i...) = variable(n, i..., T=FnType)
169170
# Expand out all of the variables in time and by stages
170171
timed_vars = [[var(operation(x),i)(iv) for i in 1:n+1] for x in states(sys)]
171172
k_vars = [[var(Symbol(:ᵏ,nameof(operation(x))),i,j)(iv) for i in 1:m, j in 1:n] for x in states(sys)]
@@ -196,5 +197,5 @@ function runge_kutta_discretize(sys::ControlSystem,dt,tspan;
196197
equalities = vcat(stages,updates,control_equality)
197198
opt_states = vcat(reduce(vcat,reduce(vcat,states_timeseries)),reduce(vcat,reduce(vcat,k_timeseries)),reduce(vcat,reduce(vcat,control_timeseries)))
198199

199-
OptimizationSystem(reduce(+,losses, init=0),opt_states,ps,equality_constraints = equalities)
200+
OptimizationSystem(reduce(+,losses, init=0),opt_states,ps,equality_constraints = equalities, name=nameof(sys))
200201
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function calculate_tgrad(sys::AbstractODESystem;
1010
xs = states(sys)
1111
rule = Dict(map((x, xt) -> xt=>x, detime_dvs.(xs), xs))
1212
rhs = substitute.(rhs, Ref(rule))
13-
tgrad = [expand_derivatives(ModelingToolkit.Differential(iv)(r), simplify) for r in rhs]
13+
tgrad = [expand_derivatives(Differential(iv)(r), simplify) for r in rhs]
1414
reverse_rule = Dict(map((x, xt) -> x=>xt, detime_dvs.(xs), xs))
1515
tgrad = Num.(substitute.(tgrad, Ref(reverse_rule)))
1616
get_tgrad(sys)[] = tgrad
@@ -102,10 +102,16 @@ function generate_function(
102102
p = map(x->time_varying_as_func(value(x), sys), ps)
103103
t = get_iv(sys)
104104

105+
if has_preface(sys) && (pre = preface(sys); pre !== nothing)
106+
pre = ex -> Let(pre, ex)
107+
else
108+
pre = ex -> ex
109+
end
110+
105111
if implicit_dae
106-
build_function(rhss, ddvs, u, p, t; kwargs...)
112+
build_function(rhss, ddvs, u, p, t; postprocess_fbody=pre, kwargs...)
107113
else
108-
build_function(rhss, u, p, t; kwargs...)
114+
build_function(rhss, u, p, t; postprocess_fbody=pre, kwargs...)
109115
end
110116
end
111117

@@ -409,13 +415,13 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
409415
=#
410416

411417
fsym = gensym(:f)
412-
_f = :($fsym = ModelingToolkit.ODEFunctionClosure($f_oop, $f_iip))
418+
_f = :($fsym = $ODEFunctionClosure($f_oop, $f_iip))
413419
tgradsym = gensym(:tgrad)
414420
if tgrad
415421
tgrad_oop, tgrad_iip = generate_tgrad(sys, dvs, ps;
416422
simplify=simplify,
417423
expression=Val{true}, kwargs...)
418-
_tgrad = :($tgradsym = ModelingToolkit.ODEFunctionClosure($tgrad_oop, $tgrad_iip))
424+
_tgrad = :($tgradsym = $ODEFunctionClosure($tgrad_oop, $tgrad_iip))
419425
else
420426
_tgrad = :($tgradsym = nothing)
421427
end
@@ -425,7 +431,7 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
425431
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps;
426432
sparse=sparse, simplify=simplify,
427433
expression=Val{true}, kwargs...)
428-
_jac = :($jacsym = ModelingToolkit.ODEFunctionClosure($jac_oop, $jac_iip))
434+
_jac = :($jacsym = $ODEFunctionClosure($jac_oop, $jac_iip))
429435
else
430436
_jac = :($jacsym = nothing)
431437
end

0 commit comments

Comments
 (0)