Skip to content

Commit 8906854

Browse files
Merge pull request #3693 from AayushSabharwal/as/v10-concrete-getu
fix: fix major compile time regression due to `concrete_getu`
2 parents 66cc813 + 53fb05f commit 8906854

File tree

5 files changed

+51
-41
lines changed

5 files changed

+51
-41
lines changed

src/systems/codegen.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,8 @@ Generates a function that computes the observed value(s) `ts` in the system `sys
943943
- `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist.
944944
- `mkarray`: only used if the output is an array (that is, `!isscalar(ts)` and `ts` is not a tuple, in which case the result will always be a tuple). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in the array and `output_type` is the argument of the same name passed to build_explicit_observed_function.
945945
- `cse = true`: Whether to use Common Subexpression Elimination (CSE) to generate a more efficient function.
946+
- `wrap_delays = is_dde(sys)`: Whether to add an argument for the history function and use
947+
it to calculate all delayed variables.
946948
947949
## Returns
948950
@@ -981,7 +983,8 @@ function build_explicit_observed_function(sys, ts;
981983
op = Operator,
982984
throw = true,
983985
cse = true,
984-
mkarray = nothing)
986+
mkarray = nothing,
987+
wrap_delays = is_dde(sys))
985988
# TODO: cleanup
986989
is_tuple = ts isa Tuple
987990
if is_tuple
@@ -1068,14 +1071,15 @@ function build_explicit_observed_function(sys, ts;
10681071
p_end = length(dvs) + length(inputs) + length(ps)
10691072
fns = build_function_wrapper(
10701073
sys, ts, args...; p_start, p_end, filter_observed = obsfilter,
1071-
output_type, mkarray, try_namespaced = true, expression = Val{true}, cse)
1074+
output_type, mkarray, try_namespaced = true, expression = Val{true}, cse,
1075+
wrap_delays)
10721076
if fns isa Tuple
10731077
if expression
10741078
return return_inplace ? fns : fns[1]
10751079
end
10761080
oop, iip = eval_or_rgf.(fns; eval_expression, eval_module)
10771081
f = GeneratedFunctionWrapper{(
1078-
p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}(
1082+
p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}(
10791083
oop, iip)
10801084
return return_inplace ? (f, f) : f
10811085
else
@@ -1084,7 +1088,7 @@ function build_explicit_observed_function(sys, ts;
10841088
end
10851089
f = eval_or_rgf(fns; eval_expression, eval_module)
10861090
f = GeneratedFunctionWrapper{(
1087-
p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}(
1091+
p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}(
10881092
f, nothing)
10891093
return f
10901094
end

src/systems/problem_utils.jl

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -646,30 +646,40 @@ struct ReconstructInitializeprob{GP, GU}
646646
ugetter::GU
647647
end
648648

649+
"""
650+
$(TYPEDEF)
651+
652+
A wrapper over an observed function which allows calling it on a problem-like object.
653+
`TD` determines whether the getter function is `(u, p, t)` (if `true`) or `(u, p)` (if
654+
`false`).
655+
"""
656+
struct ObservedWrapper{TD, F}
657+
f::F
658+
end
659+
660+
ObservedWrapper{TD}(f::F) where {TD, F} = ObservedWrapper{TD, F}(f)
661+
662+
function (ow::ObservedWrapper{true})(prob)
663+
ow.f(state_values(prob), parameter_values(prob), current_time(prob))
664+
end
665+
666+
function (ow::ObservedWrapper{false})(prob)
667+
ow.f(state_values(prob), parameter_values(prob))
668+
end
669+
649670
"""
650671
$(TYPEDSIGNATURES)
651672
652673
Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter
653-
function by splitting `syms` into contiguous buffers where the getter of each buffer
654-
is type-stable and constructing a function that calls and concatenates the results.
655-
"""
656-
function concrete_getu(indp, syms::AbstractVector)
657-
# a list of contiguous buffer
658-
split_syms = [Any[syms[1]]]
659-
# the type of the getter of the last buffer
660-
current = typeof(getu(indp, syms[1]))
661-
for sym in syms[2:end]
662-
getter = getu(indp, sym)
663-
if typeof(getter) != current
664-
# if types don't match, build a new buffer
665-
push!(split_syms, [])
666-
current = typeof(getter)
667-
end
668-
push!(split_syms[end], sym)
669-
end
670-
split_syms = Tuple(split_syms)
671-
# the getter is now type-stable, and we can vcat it to get the full buffer
672-
return Base.Fix1(reduce, vcat) getu(indp, split_syms)
674+
function.
675+
676+
Note that the getter ONLY works for problem-like objects, since it generates an observed
677+
function. It does NOT work for solutions.
678+
"""
679+
Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector)
680+
@nospecialize
681+
obsfn = build_explicit_observed_function(indp, syms; wrap_delays = false)
682+
return ObservedWrapper{is_time_dependent(indp)}(obsfn)
673683
end
674684

675685
"""

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,7 @@ Keyword arguments:
830830
`available_vars` will not be searched for in the observed equations.
831831
"""
832832
function observed_equations_used_by(sys::AbstractSystem, exprs;
833-
involved_vars = vars(exprs; op = Union{Shift, Differential}), obs = observed(sys), available_vars = [])
833+
involved_vars = vars(exprs; op = Union{Shift, Differential, Initial}), obs = observed(sys), available_vars = [])
834834
obsvars = getproperty.(obs, :lhs)
835835
graph = observed_dependency_graph(obs)
836836
if !(available_vars isa Set)

test/extensions/ad.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,11 @@ sol = solve(prob, Tsit5())
2727

2828
mtkparams = parameter_values(prob)
2929
new_p = rand(14)
30-
@test_broken begin
31-
gs = gradient(new_p) do new_p
32-
new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
33-
new_prob = remake(prob, p = new_params)
34-
new_sol = solve(new_prob, Tsit5())
35-
sum(new_sol)
36-
end
30+
gs = gradient(new_p) do new_p
31+
new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
32+
new_prob = remake(prob, p = new_params)
33+
new_sol = solve(new_prob, Tsit5())
34+
sum(new_sol)
3735
end
3836

3937
@testset "Issue#2997" begin

test/odesystem.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@ using OrdinaryDiffEq, Sundials
66
using DiffEqBase, SparseArrays
77
using StaticArrays
88
using Test
9-
using SymbolicUtils: issym
9+
using SymbolicUtils.Code
10+
using SymbolicUtils: Sym, issym
1011
using ForwardDiff
1112
using ModelingToolkit: value
1213
using ModelingToolkit: t_nounits as t, D_nounits as D
14+
using Symbolics
1315
using Symbolics: unwrap
16+
using DiffEqBase: isinplace
1417

1518
# Define some variables
1619
@parameters σ ρ β
@@ -505,13 +508,6 @@ sys = complete(sys)
505508
@test_throws Any ODEFunction(sys)
506509

507510
@testset "Preface tests" begin
508-
using OrdinaryDiffEq
509-
using Symbolics
510-
using DiffEqBase: isinplace
511-
using ModelingToolkit
512-
using SymbolicUtils.Code
513-
using SymbolicUtils: Sym
514-
515511
c = [0]
516512
function f(c, du::AbstractVector{Float64}, u::AbstractVector{Float64}, p, t::Float64)
517513
c .= [c[1] + 1]
@@ -554,7 +550,9 @@ sys = complete(sys)
554550

555551
@named sys = System(eqs, t, us, ps; defaults = defs, preface = preface)
556552
sys = complete(sys)
557-
prob = ODEProblem(sys, [], (0.0, 1.0))
553+
# don't build initializeprob because it will use preface in other functions and
554+
# affect `c`
555+
prob = ODEProblem(sys, [], (0.0, 1.0); build_initializeprob = false)
558556
sol = solve(prob, Euler(); dt = 0.1)
559557

560558
@test c[1] == length(sol)

0 commit comments

Comments
 (0)