Skip to content

Commit 166606a

Browse files
authored
Merge pull request #2055 from SciML/myb/acc
Add accumulation
2 parents 04d2c6b + 034fc86 commit 166606a

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ end
162162

163163
export AbstractTimeDependentSystem, AbstractTimeIndependentSystem,
164164
AbstractMultivariateSystem
165-
export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr, convert_system
165+
export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr, convert_system,
166+
add_accumulations
166167
export DAEFunctionExpr, DAEProblemExpr
167168
export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr
168169
export SystemStructure

src/systems/diffeqs/odesystem.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,3 +436,21 @@ function convert_system(::Type{<:ODESystem}, sys, t; name = nameof(sys))
436436
return ODESystem(neweqs, t, newsts, parameters(sys); defaults = defs, name = name,
437437
checks = false)
438438
end
439+
440+
"""
441+
$(SIGNATURES)
442+
443+
Add accumulation variables for `vars`.
444+
"""
445+
function add_accumulations(sys::ODESystem, vars = states(sys))
446+
eqs = get_eqs(sys)
447+
accs = filter(x -> startswith(string(x), "accumulation_"), states(sys))
448+
if !isempty(accs)
449+
error("$accs variable names start with \"accumulation_\"")
450+
end
451+
avars = [rename(v, Symbol(:accumulation_, getname(v))) for v in vars]
452+
D = Differential(get_iv(sys))
453+
@set! sys.eqs = [eqs; Equation[D(a) ~ v for (a, v) in zip(avars, vars)]]
454+
@set! sys.states = [get_states(sys); avars]
455+
@set! sys.defaults = merge(get_defaults(sys), Dict(a => 0.0 for a in avars))
456+
end

test/odesystem.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,15 @@ D = Differential(t)
331331
@variables x(t) y(t) z(t)
332332
D = Differential(t)
333333
@named sys = ODESystem([D(x) ~ y, 0 ~ x + z, 0 ~ x - y], t, [z, y, x], [])
334+
asys = add_accumulations(sys)
335+
@variables accumulation_x(t) accumulation_y(t) accumulation_z(t)
336+
eqs = [0 ~ x + z
337+
0 ~ x - y
338+
D(accumulation_x) ~ x
339+
D(accumulation_y) ~ y
340+
D(accumulation_z) ~ z
341+
D(x) ~ y]
342+
@test sort(equations(asys), by = string) == eqs
334343
sys2 = ode_order_lowering(sys)
335344
M = ModelingToolkit.calculate_massmatrix(sys2)
336345
@test M == Diagonal([1, 0, 0])

0 commit comments

Comments
 (0)