Skip to content

Commit 38ac571

Browse files
Merge pull request #1059 from cadojo/control_input_1041
Control parameter specification for `AbstractODESystem`, `DiscreteSystem`
2 parents d7b8f96 + 47c0039 commit 38ac571

File tree

11 files changed

+123
-26
lines changed

11 files changed

+123
-26
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ export structural_simplify
166166
export DiscreteSystem, DiscreteProblem
167167

168168
export calculate_jacobian, generate_jacobian, generate_function
169+
export calculate_control_jacobian, generate_control_jacobian
169170
export calculate_tgrad, generate_tgrad
170171
export calculate_gradient, generate_gradient
171172
export calculate_factorized_W, generate_factorized_W

src/systems/abstractsystem.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,18 @@ call will be cached in the system object.
3434
"""
3535
function calculate_jacobian end
3636

37+
"""
38+
```julia
39+
calculate_control_jacobian(sys::AbstractSystem)
40+
```
41+
42+
Calculate the jacobian matrix of a system with respect to the system's controls.
43+
44+
Returns a matrix of [`Num`](@ref) instances. The result from the first
45+
call will be cached in the system object.
46+
"""
47+
function calculate_control_jacobian end
48+
3749
"""
3850
```julia
3951
calculate_factorized_W(sys::AbstractSystem)
@@ -140,10 +152,12 @@ for prop in [
140152
:iv
141153
:states
142154
:ps
155+
:ctrls
143156
:defaults
144157
:observed
145158
:tgrad
146159
:jac
160+
:ctrl_jac
147161
:Wfact
148162
:Wfact_t
149163
:systems
@@ -301,6 +315,7 @@ end
301315

302316
namespace_variables(sys::AbstractSystem) = states(sys, states(sys))
303317
namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))
318+
namespace_controls(sys::AbstractSystem) = controls(sys, controls(sys))
304319

305320
function namespace_defaults(sys)
306321
defs = defaults(sys)
@@ -344,13 +359,21 @@ function states(sys::AbstractSystem)
344359
systems = get_systems(sys)
345360
unique(isempty(systems) ?
346361
sts :
347-
[sts;reduce(vcat,namespace_variables.(systems))])
362+
[sts; reduce(vcat,namespace_variables.(systems))])
348363
end
364+
349365
function parameters(sys::AbstractSystem)
350366
ps = get_ps(sys)
351367
systems = get_systems(sys)
352-
isempty(systems) ? ps : [ps;reduce(vcat,namespace_parameters.(systems))]
368+
isempty(systems) ? ps : [ps; reduce(vcat,namespace_parameters.(systems))]
353369
end
370+
371+
function controls(sys::AbstractSystem)
372+
ctrls = get_ctrls(sys)
373+
systems = get_systems(sys)
374+
isempty(systems) ? ctrls : [ctrls; reduce(vcat,namespace_controls.(systems))]
375+
end
376+
354377
function observed(sys::AbstractSystem)
355378
iv = independent_variable(sys)
356379
obs = get_observed(sys)

src/systems/control/controlsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
abstract type AbstractControlSystem <: AbstractSystem end
22

3-
function namespace_controls(sys::AbstractSystem)
3+
function namespace_controls(sys::AbstractControlSystem)
44
[rename(x,renamespace(nameof(sys),nameof(x))) for x in controls(sys)]
55
end
66

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,28 @@ function calculate_jacobian(sys::AbstractODESystem;
3838
return jac
3939
end
4040

41+
function calculate_control_jacobian(sys::AbstractODESystem;
42+
sparse=false, simplify=false)
43+
cache = get_ctrl_jac(sys)[]
44+
if cache isa Tuple && cache[2] == (sparse, simplify)
45+
return cache[1]
46+
end
47+
48+
rhs = [eq.rhs for eq equations(sys)]
49+
50+
iv = get_iv(sys)
51+
ctrls = controls(sys)
52+
53+
if sparse
54+
jac = sparsejacobian(rhs, ctrls, simplify=simplify)
55+
else
56+
jac = jacobian(rhs, ctrls, simplify=simplify)
57+
end
58+
59+
get_ctrl_jac(sys)[] = jac, (sparse, simplify) # cache Jacobian
60+
return jac
61+
end
62+
4163
function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
4264
simplify=false, kwargs...)
4365
tgrad = calculate_tgrad(sys,simplify=simplify)
@@ -50,6 +72,12 @@ function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = param
5072
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
5173
end
5274

75+
function generate_control_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
76+
simplify=false, sparse = false, kwargs...)
77+
jac = calculate_control_jacobian(sys;simplify=simplify,sparse=sparse)
78+
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
79+
end
80+
5381
@noinline function throw_invalid_derivative(dervar, eq)
5482
msg = "The derivative variable must be isolated to the left-hand " *
5583
"side of the equation like `$dervar ~ ...`.\n Got $eq."

src/systems/diffeqs/odesystem.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ struct ODESystem <: AbstractODESystem
3131
states::Vector
3232
"""Parameter variables. Must not contain the independent variable."""
3333
ps::Vector
34+
"""Control parameters (some subset of `ps`)."""
35+
ctrls::Vector
36+
"""Observed states."""
3437
observed::Vector{Equation}
3538
"""
3639
Time-derivative matrix. Note: this field will not be defined until
@@ -43,6 +46,11 @@ struct ODESystem <: AbstractODESystem
4346
"""
4447
jac::RefValue{Any}
4548
"""
49+
Control Jacobian matrix. Note: this field will not be defined until
50+
[`calculate_control_jacobian`](@ref) is called on the system.
51+
"""
52+
ctrl_jac::RefValue{Any}
53+
"""
4654
`Wfact` matrix. Note: this field will not be defined until
4755
[`generate_factorized_W`](@ref) is called on the system.
4856
"""
@@ -74,16 +82,17 @@ struct ODESystem <: AbstractODESystem
7482
"""
7583
connection_type::Any
7684

77-
function ODESystem(deqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
85+
function ODESystem(deqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
7886
check_variables(dvs,iv)
7987
check_parameters(ps,iv)
8088
check_equations(deqs,iv)
81-
new(deqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
89+
new(deqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
8290
end
8391
end
8492

8593
function ODESystem(
8694
deqs::AbstractVector{<:Equation}, iv, dvs, ps;
95+
controls = Num[],
8796
observed = Num[],
8897
systems = ODESystem[],
8998
name=gensym(:ODESystem),
@@ -92,9 +101,13 @@ function ODESystem(
92101
defaults=_merge(Dict(default_u0), Dict(default_p)),
93102
connection_type=nothing,
94103
)
104+
105+
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
106+
95107
iv′ = value(scalarize(iv))
96108
dvs′ = value.(scalarize(dvs))
97109
ps′ = value.(scalarize(ps))
110+
ctrl′ = value.(scalarize(controls))
98111

99112
if !(isempty(default_u0) && isempty(default_p))
100113
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :ODESystem, force=true)
@@ -104,13 +117,14 @@ function ODESystem(
104117

105118
tgrad = RefValue(Vector{Num}(undef, 0))
106119
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
120+
ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
107121
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
108122
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
109123
sysnames = nameof.(systems)
110124
if length(unique(sysnames)) != length(sysnames)
111125
throw(ArgumentError("System names must be unique."))
112126
end
113-
ODESystem(deqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
127+
ODESystem(deqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
114128
end
115129

116130
vars(x::Sym) = Set([x])
@@ -349,4 +363,4 @@ function convert_system(::Type{<:ODESystem}, sys, t; name=nameof(sys))
349363
neweqs = map(sub, equations(sys))
350364
defs = Dict(sub(k) => sub(v) for (k, v) in defaults(sys))
351365
return ODESystem(neweqs, t, newsts, parameters(sys); defaults=defs, name=name)
352-
end
366+
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ struct SDESystem <: AbstractODESystem
3737
states::Vector
3838
"""Parameter variables. Must not contain the independent variable."""
3939
ps::Vector
40-
observed::Vector
40+
"""Control parameters (some subset of `ps`)."""
41+
ctrls::Vector
42+
"""Observed states."""
43+
observed::Vector{Equation}
4144
"""
4245
Time-derivative matrix. Note: this field will not be defined until
4346
[`calculate_tgrad`](@ref) is called on the system.
@@ -49,6 +52,11 @@ struct SDESystem <: AbstractODESystem
4952
"""
5053
jac::RefValue
5154
"""
55+
Control Jacobian matrix. Note: this field will not be defined until
56+
[`calculate_control_jacobian`](@ref) is called on the system.
57+
"""
58+
ctrl_jac::RefValue{Any}
59+
"""
5260
`Wfact` matrix. Note: this field will not be defined until
5361
[`generate_factorized_W`](@ref) is called on the system.
5462
"""
@@ -76,16 +84,17 @@ struct SDESystem <: AbstractODESystem
7684
"""
7785
connection_type::Any
7886

79-
function SDESystem(deqs, neqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
87+
function SDESystem(deqs, neqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
8088
check_variables(dvs,iv)
8189
check_parameters(ps,iv)
8290
check_equations(deqs,iv)
83-
new(deqs, neqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
91+
new(deqs, neqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
8492
end
8593
end
8694

8795
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
88-
observed = [],
96+
controls = Num[],
97+
observed = Num[],
8998
systems = SDESystem[],
9099
default_u0=Dict(),
91100
default_p=Dict(),
@@ -96,6 +105,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
96105
iv′ = value(iv)
97106
dvs′ = value.(dvs)
98107
ps′ = value.(ps)
108+
ctrl′ = value.(controls)
109+
99110
sysnames = nameof.(systems)
100111
if length(unique(sysnames)) != length(sysnames)
101112
throw(ArgumentError("System names must be unique."))
@@ -108,9 +119,10 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
108119

109120
tgrad = RefValue(Vector{Num}(undef, 0))
110121
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
122+
ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
111123
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
112124
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
113-
SDESystem(deqs, neqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
125+
SDESystem(deqs, neqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
114126
end
115127

116128
function generate_diffusion_function(sys::SDESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
@@ -157,10 +169,6 @@ function stochastic_integral_transform(sys::SDESystem, correction_factor)
157169
SDESystem(deqs,get_noiseeqs(sys),get_iv(sys),states(sys),parameters(sys))
158170
end
159171

160-
161-
162-
163-
164172
"""
165173
```julia
166174
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps;

src/systems/discrete_system/discrete_system.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ struct DiscreteSystem <: AbstractSystem
3030
states::Vector
3131
"""Parameter variables. Must not contain the independent variable."""
3232
ps::Vector
33+
"""Control parameters (some subset of `ps`)."""
34+
ctrls::Vector
35+
"""Observed states."""
3336
observed::Vector{Equation}
3437
"""
3538
Name: the name of the system
@@ -49,10 +52,10 @@ struct DiscreteSystem <: AbstractSystem
4952
in `DiscreteSystem`.
5053
"""
5154
default_p::Dict
52-
function DiscreteSystem(discreteEqs, iv, dvs, ps, observed, name, systems, default_u0, default_p)
53-
check_variables(dvs, iv)
54-
check_parameters(ps, iv)
55-
new(discreteEqs, iv, dvs, ps, observed, name, systems, default_u0, default_p)
55+
function DiscreteSystem(discreteEqs, iv, dvs, ps, ctrls, observed, name, systems, default_u0, default_p)
56+
check_variables(dvs,iv)
57+
check_parameters(ps,iv)
58+
new(discreteEqs, iv, dvs, ps, ctrls, observed, name, systems, default_u0, default_p)
5659
end
5760
end
5861

@@ -63,6 +66,7 @@ Constructs a DiscreteSystem.
6366
"""
6467
function DiscreteSystem(
6568
discreteEqs::AbstractVector{<:Equation}, iv, dvs, ps;
69+
controls = Num[],
6670
observed = Num[],
6771
systems = DiscreteSystem[],
6872
name=gensym(:DiscreteSystem),
@@ -72,6 +76,7 @@ function DiscreteSystem(
7276
iv′ = value(iv)
7377
dvs′ = value.(dvs)
7478
ps′ = value.(ps)
79+
ctrl′ = value.(controls)
7580

7681
default_u0 isa Dict || (default_u0 = Dict(default_u0))
7782
default_p isa Dict || (default_p = Dict(default_p))
@@ -82,7 +87,7 @@ function DiscreteSystem(
8287
if length(unique(sysnames)) != length(sysnames)
8388
throw(ArgumentError("System names must be unique."))
8489
end
85-
DiscreteSystem(discreteEqs, iv′, dvs′, ps′, observed, name, systems, default_u0, default_p)
90+
DiscreteSystem(discreteEqs, iv′, dvs′, ps′, ctrl′, observed, name, systems, default_u0, default_p)
8691
end
8792

8893
"""

test/direct.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ end
7575
@variables a,b
7676
X = [a,b]
7777

78-
spoly(x) = simplify(x, polynorm=true)
78+
spoly(x) = simplify(x, expand=true)
7979
rr = rosenbrock(X)
8080

8181
reference_hes = ModelingToolkit.hessian(rr, X)

test/discretesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ eqs = [next_S ~ S-infection,
2222
next_R ~ R+recovery]
2323

2424
# System
25-
sys = DiscreteSystem(eqs,t,[S,I,R],[c,nsteps,δt,β,γ])
25+
sys = DiscreteSystem(eqs,t,[S,I,R],[c,nsteps,δt,β,γ]; controls = [β, γ])
2626

2727
# Problem
2828
u0 = [S => 990.0, I => 10.0, R => 0.0]
@@ -54,4 +54,4 @@ p = [0.05,10.0,0.25,0.1];
5454
prob_map = DiscreteProblem(sir_map!,u0,tspan,p);
5555
sol_map2 = solve(prob_map,FunctionMap());
5656

57-
@test Array(sol_map) Array(sol_map2)
57+
@test Array(sol_map) Array(sol_map2)

test/odesystem.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,4 +361,22 @@ D =Differential(t)
361361
eqs = [D(x1) ~ -x1]
362362
sys = ODESystem(eqs,t,[x1,x2],[])
363363
@test_throws ArgumentError ODEProblem(sys, [1.0,1.0], (0.0,1.0))
364-
prob = ODEProblem(sys, [1.0,1.0], (0.0,1.0), check_length=false)
364+
prob = ODEProblem(sys, [1.0,1.0], (0.0,1.0), check_length=false)
365+
366+
# check inputs
367+
let
368+
@parameters t f k d
369+
@variables x(t) (t)
370+
δ = Differential(t)
371+
372+
eqs = [δ(x) ~ ẋ, δ(ẋ) ~ f - k*x - d*ẋ]
373+
sys = ODESystem(eqs, t, [x, ẋ], [f, d, k]; controls = [f])
374+
375+
calculate_control_jacobian(sys)
376+
377+
@test isequal(
378+
calculate_control_jacobian(sys),
379+
reshape(Num[0,1], 2, 1)
380+
)
381+
382+
end

test/reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ ref_eq = [
2727
@variables x(t) y(t) z(t) a(t) u(t) F(t)
2828
D = Differential(t)
2929

30-
test_equal(a, b) = @test isequal(simplify(a, polynorm=true), simplify(b, polynorm=true))
30+
test_equal(a, b) = @test isequal(simplify(a, expand=true), simplify(b, expand=true))
3131

3232
eqs = [
3333
D(x) ~ σ*(y-x)

0 commit comments

Comments
 (0)