diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index f8cf12f0f7..e7ae694625 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -550,7 +550,7 @@ end flatten(sys::AbstractSystem, args...) = sys -function equations(sys::ModelingToolkit.AbstractSystem) +function equations(sys::AbstractSystem) eqs = get_eqs(sys) systems = get_systems(sys) if isempty(systems) @@ -564,7 +564,7 @@ function equations(sys::ModelingToolkit.AbstractSystem) end end -function preface(sys::ModelingToolkit.AbstractSystem) +function preface(sys::AbstractSystem) has_preface(sys) || return nothing pre = get_preface(sys) systems = get_systems(sys) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 1bf06f9c7a..61654b0d5a 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -437,6 +437,13 @@ function convert_system(::Type{<:ODESystem}, sys, t; name = nameof(sys)) checks = false) end +function Symbolics.substitute(sys::ODESystem, rules::Union{Vector{<:Pair}, Dict}) + rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]), + collect(rules))) + eqs = fast_substitute(equations(sys), rules) + ODESystem(eqs, get_iv(sys); name = nameof(sys)) +end + """ $(SIGNATURES) diff --git a/test/odesystem.jl b/test/odesystem.jl index ef1b4433e4..a43df6a42a 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -12,6 +12,7 @@ using ModelingToolkit: value @constants κ = 1 @variables x(t) y(t) z(t) D = Differential(t) +@parameters k # Define a differential equation eqs = [D(x) ~ σ * (y - x), @@ -20,6 +21,12 @@ eqs = [D(x) ~ σ * (y - x), ModelingToolkit.toexpr.(eqs)[1] @named de = ODESystem(eqs; defaults = Dict(x => 1)) +subed = substitute(de, [σ => k]) +@test isequal(sort(parameters(subed), by = string), [k, β, ρ]) +@test isequal(equations(subed), + [D(x) ~ k * (y - x) + D(y) ~ (ρ - z) * x - y + D(z) ~ x * y - β * κ * z]) @named des[1:3] = ODESystem(eqs) @test length(unique(x -> ModelingToolkit.get_tag(x), des)) == 1 @test eval(toexpr(de)) == de