diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 53e831fca4..f8a48a68d5 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -162,7 +162,8 @@ end export AbstractTimeDependentSystem, AbstractTimeIndependentSystem, AbstractMultivariateSystem -export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr, convert_system +export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr, convert_system, + add_accumulations export DAEFunctionExpr, DAEProblemExpr export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr export SystemStructure diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 2d69e91df2..1bf06f9c7a 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -436,3 +436,21 @@ function convert_system(::Type{<:ODESystem}, sys, t; name = nameof(sys)) return ODESystem(neweqs, t, newsts, parameters(sys); defaults = defs, name = name, checks = false) end + +""" +$(SIGNATURES) + +Add accumulation variables for `vars`. +""" +function add_accumulations(sys::ODESystem, vars = states(sys)) + eqs = get_eqs(sys) + accs = filter(x -> startswith(string(x), "accumulation_"), states(sys)) + if !isempty(accs) + error("$accs variable names start with \"accumulation_\"") + end + avars = [rename(v, Symbol(:accumulation_, getname(v))) for v in vars] + D = Differential(get_iv(sys)) + @set! sys.eqs = [eqs; Equation[D(a) ~ v for (a, v) in zip(avars, vars)]] + @set! sys.states = [get_states(sys); avars] + @set! sys.defaults = merge(get_defaults(sys), Dict(a => 0.0 for a in avars)) +end diff --git a/test/odesystem.jl b/test/odesystem.jl index bed24dffff..ef1b4433e4 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -331,6 +331,15 @@ D = Differential(t) @variables x(t) y(t) z(t) D = Differential(t) @named sys = ODESystem([D(x) ~ y, 0 ~ x + z, 0 ~ x - y], t, [z, y, x], []) +asys = add_accumulations(sys) +@variables accumulation_x(t) accumulation_y(t) accumulation_z(t) +eqs = [0 ~ x + z + 0 ~ x - y + D(accumulation_x) ~ x + D(accumulation_y) ~ y + D(accumulation_z) ~ z + D(x) ~ y] +@test sort(equations(asys), by = string) == eqs sys2 = ode_order_lowering(sys) M = ModelingToolkit.calculate_massmatrix(sys2) @test M == Diagonal([1, 0, 0])