Skip to content

Commit 064efaf

Browse files
feat: allow passing wrap_delays to build_explicit_observed_function
1 parent 66cc813 commit 064efaf

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/systems/codegen.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,8 @@ Generates a function that computes the observed value(s) `ts` in the system `sys
943943
- `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist.
944944
- `mkarray`: only used if the output is an array (that is, `!isscalar(ts)` and `ts` is not a tuple, in which case the result will always be a tuple). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in the array and `output_type` is the argument of the same name passed to build_explicit_observed_function.
945945
- `cse = true`: Whether to use Common Subexpression Elimination (CSE) to generate a more efficient function.
946+
- `wrap_delays = is_dde(sys)`: Whether to add an argument for the history function and use
947+
it to calculate all delayed variables.
946948
947949
## Returns
948950
@@ -981,7 +983,8 @@ function build_explicit_observed_function(sys, ts;
981983
op = Operator,
982984
throw = true,
983985
cse = true,
984-
mkarray = nothing)
986+
mkarray = nothing,
987+
wrap_delays = is_dde(sys))
985988
# TODO: cleanup
986989
is_tuple = ts isa Tuple
987990
if is_tuple
@@ -1068,14 +1071,15 @@ function build_explicit_observed_function(sys, ts;
10681071
p_end = length(dvs) + length(inputs) + length(ps)
10691072
fns = build_function_wrapper(
10701073
sys, ts, args...; p_start, p_end, filter_observed = obsfilter,
1071-
output_type, mkarray, try_namespaced = true, expression = Val{true}, cse)
1074+
output_type, mkarray, try_namespaced = true, expression = Val{true}, cse,
1075+
wrap_delays)
10721076
if fns isa Tuple
10731077
if expression
10741078
return return_inplace ? fns : fns[1]
10751079
end
10761080
oop, iip = eval_or_rgf.(fns; eval_expression, eval_module)
10771081
f = GeneratedFunctionWrapper{(
1078-
p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}(
1082+
p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}(
10791083
oop, iip)
10801084
return return_inplace ? (f, f) : f
10811085
else
@@ -1084,7 +1088,7 @@ function build_explicit_observed_function(sys, ts;
10841088
end
10851089
f = eval_or_rgf(fns; eval_expression, eval_module)
10861090
f = GeneratedFunctionWrapper{(
1087-
p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}(
1091+
p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}(
10881092
f, nothing)
10891093
return f
10901094
end

0 commit comments

Comments
 (0)