Skip to content

Commit db720c1

Browse files
Merge pull request #3724 from AayushSabharwal/as/redundant-codegen
fix: do not unwrap initials in `initializeprobpmap`
2 parents b55646e + 3518b39 commit db720c1

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

src/systems/problem_utils.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,8 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
736736
# Keyword Arguments
737737
- `initials`: Whether to include the `Initial` parameters of `dstsys` among the values
738738
to be transferred.
739+
- `unwrap_initials`: Whether initials in `dstsys` corresponding to unknowns in `srcsys` are
740+
unwrapped.
739741
- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
740742
"""
741743
function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem;
@@ -758,7 +760,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
758760
end
759761
initials_getter = if initials && !isempty(syms[2])
760762
initsyms = Vector{Any}(syms[2])
761-
allsyms = Set(all_symbols(srcsys))
763+
allsyms = Set(variable_symbols(srcsys))
762764
if unwrap_initials
763765
for i in eachindex(initsyms)
764766
sym = initsyms[i]

test/code_generation.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,33 @@ end
110110
@test val[] == 2
111111
end
112112
end
113+
114+
@testset "Do not codegen redundant expressions" begin
115+
@variables v1(t) = 1
116+
@variables v2(t) [guess = 0]
117+
118+
mutable struct Data
119+
count::Int
120+
end
121+
function update!(d::Data, t)
122+
d.count += 1 # Count the number of times the data gets updated.
123+
end
124+
function (d::Data)(t)
125+
update!(d, t)
126+
rand(1:10)
127+
end
128+
129+
@parameters (d1::Data)(..) = Data(0)
130+
@parameters (d2::Data)(..) = Data(0)
131+
132+
eqs = [
133+
D(v1) ~ d1(t),
134+
v2 ~ d2(t) # Some of the data parameters are not actually needed to solve the system.
135+
]
136+
137+
@mtkbuild sys = System(eqs, t)
138+
prob = ODEProblem(sys, [], (0.0, 1.0))
139+
sol = solve(prob, Tsit5())
140+
141+
@test sol.ps[d2].count == 0
142+
end

0 commit comments

Comments
 (0)