Skip to content

Commit a779f4b

Browse files
refactor!: require systems to be completed before creating an XFunction/XFunctionExpr
1 parent 9b41617 commit a779f4b

File tree

10 files changed

+46
-5
lines changed

10 files changed

+46
-5
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = u
308308
analytic = nothing,
309309
split_idxs = nothing,
310310
kwargs...) where {iip, specialize}
311+
if !iscomplete(sys)
312+
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
313+
end
311314
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
312315
expression_module = eval_module, checkbounds = checkbounds,
313316
kwargs...)
@@ -504,6 +507,9 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
504507
eval_module = @__MODULE__,
505508
checkbounds = false,
506509
kwargs...) where {iip}
510+
if !iscomplete(sys)
511+
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
512+
end
507513
f_gen = generate_function(sys, dvs, ps; implicit_dae = true,
508514
expression = Val{eval_expression},
509515
expression_module = eval_module, checkbounds = checkbounds,
@@ -579,6 +585,9 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
579585
eval_module = @__MODULE__,
580586
checkbounds = false,
581587
kwargs...) where {iip}
588+
if !iscomplete(sys)
589+
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `DDEFunction`")
590+
end
582591
f_gen = generate_function(sys, dvs, ps; isdde = true,
583592
expression = Val{true},
584593
expression_module = eval_module, checkbounds = checkbounds,
@@ -603,6 +612,9 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
603612
eval_module = @__MODULE__,
604613
checkbounds = false,
605614
kwargs...) where {iip}
615+
if !iscomplete(sys)
616+
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `SDDEFunction`")
617+
end
606618
f_gen = generate_function(sys, dvs, ps; isdde = true,
607619
expression = Val{true},
608620
expression_module = eval_module, checkbounds = checkbounds,
@@ -656,6 +668,9 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
656668
sparsity = false,
657669
observedfun_exp = nothing,
658670
kwargs...) where {iip}
671+
if !iscomplete(sys)
672+
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunctionExpr`")
673+
end
659674
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
660675

661676
dict = Dict()
@@ -830,6 +845,9 @@ function DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
830845
linenumbers = false,
831846
sparse = false, simplify = false,
832847
kwargs...) where {iip}
848+
if !iscomplete(sys)
849+
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `DAEFunctionExpr`")
850+
end
833851
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true},
834852
implicit_dae = true, kwargs...)
835853
fsym = gensym(:f)

src/systems/diffeqs/sdesystem.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,9 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
393393
jac = false, Wfact = false, eval_expression = true,
394394
checkbounds = false,
395395
kwargs...) where {iip}
396+
if !iscomplete(sys)
397+
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
398+
end
396399
dvs = scalarize.(dvs)
397400
ps = scalarize.(ps)
398401

@@ -515,6 +518,9 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = unknowns(sys),
515518
jac = false, Wfact = false,
516519
sparse = false, linenumbers = false,
517520
kwargs...) where {iip}
521+
if !iscomplete(sys)
522+
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunctionExpr`")
523+
end
518524
idx = iip ? 2 : 1
519525
f = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)[idx]
520526
g = generate_diffusion_function(sys, dvs, ps; expression = Val{true}, kwargs...)[idx]

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
234234
eval_expression = true,
235235
sparse = false, simplify = false,
236236
kwargs...) where {iip}
237+
if !iscomplete(sys)
238+
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearFunction`")
239+
end
237240
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, kwargs...)
238241
f_oop, f_iip = eval_expression ?
239242
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in f_gen) : f_gen
@@ -296,6 +299,9 @@ function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
296299
linenumbers = false,
297300
sparse = false, simplify = false,
298301
kwargs...) where {iip}
302+
if !iscomplete(sys)
303+
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearFunctionExpr`")
304+
end
299305
idx = iip ? 2 : 1
300306
f = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)[idx]
301307

test/distributed.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ addprocs(2)
1414
D(z) ~ x * y - β * z]
1515

1616
@everywhere @named de = ODESystem(eqs)
17+
@everywhere de = complete(de)
1718
@everywhere ode_func = ODEFunction(de, [x, y, z], [σ, ρ, β])
1819

1920
@everywhere u0 = [19.0, 20.0, 50.0]

test/function_registration.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ end
1818

1919
eq = Dt(u) ~ do_something(x) + MyModule.do_something(x)
2020
@named sys = ODESystem([eq], t, [u], [x])
21+
sys = complete(sys)
2122
fun = ODEFunction(sys)
2223

2324
u0 = 5.0
@@ -40,6 +41,7 @@ end
4041

4142
eq = Dt(u) ~ do_something_2(x) + MyNestedModule.do_something_2(x)
4243
@named sys = ODESystem([eq], t, [u], [x])
44+
sys = complete(sys)
4345
fun = ODEFunction(sys)
4446

4547
u0 = 3.0
@@ -61,6 +63,7 @@ end
6163

6264
eq = Dt(u) ~ do_something_3(x) + (@__MODULE__).do_something_3(x)
6365
@named sys = ODESystem([eq], t, [u], [x])
66+
sys = complete(sys)
6467
fun = ODEFunction(sys)
6568

6669
u0 = 7.0
@@ -99,6 +102,7 @@ function build_ode()
99102
Dt = Differential(t)
100103
eq = Dt(u) ~ do_something_4(x) + (@__MODULE__).do_something_4(x)
101104
@named sys = ODESystem([eq], t, [u], [x])
105+
sys = complete(sys)
102106
fun = ODEFunction(sys, eval_expression = false)
103107
end
104108
function run_test()

test/labelledarrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ eqs = [D(x) ~ σ * (y - x),
1313
D(z) ~ x * y - β * z]
1414

1515
@named de = ODESystem(eqs)
16+
de = complete(de)
1617
ff = ODEFunction(de, [x, y, z], [σ, ρ, β], jac = true)
1718

1819
a = @SVector [1.0, 2.0, 3.0]

test/mass_matrix.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ eqs = [D(y[1]) ~ -k[1] * y[1] + k[3] * y[2] * y[3],
99
0 ~ y[1] + y[2] + y[3] - 1]
1010

1111
@named sys = ODESystem(eqs, t, y, k)
12+
sys = complete(sys)
1213
@test_throws ArgumentError ODESystem(eqs, y[1])
1314
M = calculate_massmatrix(sys)
1415
@test M == [1 0 0

test/odesystem.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ jac_expr = generate_jacobian(de)
5151
jac = calculate_jacobian(de)
5252
jacfun = eval(jac_expr[2])
5353

54+
de = complete(de)
5455
for f in [
5556
ODEFunction(de, [x, y, z], [σ, ρ, β], tgrad = true, jac = true),
5657
eval(ODEFunctionExpr(de, [x, y, z], [σ, ρ, β], tgrad = true, jac = true)),
@@ -167,7 +168,7 @@ lowered_eqs = [D(uˍtt) ~ 2uˍtt + uˍt + xˍt + 1
167168

168169
test_diffeq_inference("first-order transform", de1, t, [uˍtt, xˍt, uˍt, u, x], [])
169170
du = zeros(5)
170-
ODEFunction(de1, [uˍtt, xˍt, uˍt, u, x], [])(du, ones(5), nothing, 0.1)
171+
ODEFunction(complete(de1), [uˍtt, xˍt, uˍt, u, x], [])(du, ones(5), nothing, 0.1)
171172
@test du == [5.0, 3.0, 1.0, 1.0, 1.0]
172173

173174
# Internal calculations
@@ -182,7 +183,7 @@ jac = calculate_jacobian(de)
182183
@test ModelingToolkit.jacobian_sparsity(de).colptr == sparse(jac).colptr
183184
@test ModelingToolkit.jacobian_sparsity(de).rowval == sparse(jac).rowval
184185

185-
f = ODEFunction(de, [x, y, z], [σ, ρ, β])
186+
f = ODEFunction(complete(de), [x, y, z], [σ, ρ, β])
186187

187188
D = Differential(t)
188189
@parameters A B C
@@ -208,7 +209,7 @@ function lotka(u, p, t)
208209
end
209210

210211
prob = ODEProblem(ODEFunction{false}(lotka), [1.0, 1.0], (0.0, 1.0), [1.5, 1.0, 3.0, 1.0])
211-
de = modelingtoolkitize(prob)
212+
de = complete(modelingtoolkitize(prob))
212213
ODEFunction(de)(similar(prob.u0), prob.u0, prob.p, 0.1)
213214

214215
function lotka(du, u, p, t)
@@ -220,7 +221,7 @@ end
220221

221222
prob = ODEProblem(lotka, [1.0, 1.0], (0.0, 1.0), [1.5, 1.0, 3.0, 1.0])
222223

223-
de = modelingtoolkitize(prob)
224+
de = complete(modelingtoolkitize(prob))
224225
ODEFunction(de)(similar(prob.u0), prob.u0, prob.p, 0.1)
225226

226227
# automatic unknown detection for DAEs
@@ -579,11 +580,13 @@ eqs = [
579580
]
580581

581582
@named sys = ODESystem(eqs, t, [x, y, z], [α, β])
583+
sys = complete(sys)
582584
@test_throws Any ODEFunction(sys)
583585

584586
eqs = copy(eqs)
585587
eqs[end] = D(D(z)) ~ α * x - β * y
586588
@named sys = ODESystem(eqs, t, [x, y, z], [α, β])
589+
sys = complete(sys)
587590
@test_throws Any ODEFunction(sys)
588591

589592
@testset "Preface tests" begin

test/precompile_test/ODEPrecompileTest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ function system(; kwargs...)
1313
D(z) ~ x * y - β * z]
1414

1515
@named de = ODESystem(eqs)
16+
de = complete(de)
1617
return ODEFunction(de, [x, y, z], [σ, ρ, β]; kwargs...)
1718
end
1819

test/structural_transformation/index_reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ idx1_pendulum = [D(x) ~ w,
5959
# substitute the rhs
6060
0 ~ 2x * (T * x) + 2 * xˍt * xˍt + 2y * (T * y - g) + 2 * yˍt * yˍt]
6161
@named idx1_pendulum = ODESystem(idx1_pendulum, t, [x, y, w, z, xˍt, yˍt, T], [L, g])
62-
first_order_idx1_pendulum = ode_order_lowering(idx1_pendulum)
62+
first_order_idx1_pendulum = complete(ode_order_lowering(idx1_pendulum))
6363

6464
using OrdinaryDiffEq
6565
using LinearAlgebra

0 commit comments

Comments
 (0)