Skip to content

Commit 1d2c519

Browse files
feat: handle case of indexed array variable in SII impl for IndexCache
1 parent 9611e18 commit 1d2c519

File tree

1 file changed

+26
-22
lines changed

1 file changed

+26
-22
lines changed

src/systems/index_cache.jl

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -288,30 +288,25 @@ function IndexCache(sys::AbstractSystem)
288288
end
289289

290290
function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym)
291-
if sym isa Symbol
292-
sym = get(ic.symbol_to_variable, sym, nothing)
293-
sym === nothing && return false
294-
end
295-
return check_index_map(ic.unknown_idx, sym) !== nothing
291+
variable_index(ic, sym) !== nothing
296292
end
297293

298294
function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym)
299295
if sym isa Symbol
300296
sym = get(ic.symbol_to_variable, sym, nothing)
301297
sym === nothing && return nothing
302298
end
303-
return check_index_map(ic.unknown_idx, sym)
299+
idx = check_index_map(ic.unknown_idx, sym)
300+
idx === nothing || return idx
301+
iscall(sym) && operation(sym) == getindex || return nothing
302+
args = arguments(sym)
303+
idx = variable_index(ic, args[1])
304+
idx === nothing && return nothing
305+
ParameterIndex(idx.portion, (idx.idx..., args[2:end]...), idx.validate_size)
304306
end
305307

306308
function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym)
307-
if sym isa Symbol
308-
sym = get(ic.symbol_to_variable, sym, nothing)
309-
sym === nothing && return false
310-
end
311-
return check_index_map(ic.tunable_idx, sym) !== nothing ||
312-
check_index_map(ic.discrete_idx, sym) !== nothing ||
313-
check_index_map(ic.constant_idx, sym) !== nothing ||
314-
check_index_map(ic.nonnumeric_idx, sym) !== nothing
309+
parameter_index(ic, sym) !== nothing
315310
end
316311

317312
function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
@@ -331,17 +326,21 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
331326
ParameterIndex(SciMLStructures.Constants(), idx, validate_size)
332327
elseif (idx = check_index_map(ic.nonnumeric_idx, sym)) !== nothing
333328
ParameterIndex(NONNUMERIC_PORTION, idx, validate_size)
334-
else
335-
nothing
329+
elseif iscall(sym) && operation(sym) == getindex
330+
args = arguments(sym)
331+
pidx = parameter_index(ic, args[1])
332+
pidx === nothing && return nothing
333+
if pidx.portion == SciMLStructures.Tunable()
334+
ParameterIndex(pidx.portion, reshape(pidx.idx, size(args[1]))[args[2:end]...],
335+
pidx.validate_size)
336+
else
337+
ParameterIndex(pidx.portion, (pidx.idx..., args[2:end]...), pidx.validate_size)
338+
end
336339
end
337340
end
338341

339342
function SymbolicIndexingInterface.is_timeseries_parameter(ic::IndexCache, sym)
340-
if sym isa Symbol
341-
sym = get(ic.symbol_to_variable, sym, nothing)
342-
sym === nothing && return false
343-
end
344-
return check_index_map(ic.discrete_idx, sym) !== nothing
343+
timeseries_parameter_index(ic, sym) !== nothing
345344
end
346345

347346
function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sym)
@@ -350,8 +349,13 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy
350349
sym === nothing && return nothing
351350
end
352351
idx = check_index_map(ic.discrete_idx, sym)
352+
idx === nothing ||
353+
return ParameterTimeseriesIndex(idx.clock_idx, (idx.buffer_idx, idx.idx_in_clock))
354+
iscall(sym) && operation(sym) == getindex || return nothing
355+
args = arguments(sym)
356+
idx = timeseries_parameter_index(ic, args[1])
353357
idx === nothing && return nothing
354-
return ParameterTimeseriesIndex(idx.clock_idx, (idx.buffer_idx, idx.idx_in_clock))
358+
ParameterIndex(idx.portion, (idx.idx..., args[2:end]...), idx.validate_size)
355359
end
356360

357361
function check_index_map(idxmap, sym)

0 commit comments

Comments
 (0)