Skip to content

Commit 0e6e842

Browse files
feat: make parameter type validation error more descriptive
1 parent 1d2c519 commit 0e6e842

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

src/systems/parameter_buffer.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,14 +336,15 @@ function Base.copy(p::MTKParameters)
336336
end
337337

338338
function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::ParameterIndex)
339+
_ducktyped_parameter_values(p, pind)
340+
end
341+
function _ducktyped_parameter_values(p, pind::ParameterIndex)
339342
@unpack portion, idx = pind
340343
if portion isa SciMLStructures.Tunable
341344
return idx isa Int ? p.tunable[idx] : view(p.tunable, idx)
342345
end
343346
i, j, k... = idx
344-
if portion isa SciMLStructures.Tunable
345-
return isempty(k) ? p.tunable[i][j] : p.tunable[i][j][k...]
346-
elseif portion isa SciMLStructures.Discrete
347+
if portion isa SciMLStructures.Discrete
347348
return isempty(k) ? p.discrete[i][j] : p.discrete[i][j][k...]
348349
elseif portion isa SciMLStructures.Constants
349350
return isempty(k) ? p.constant[i][j] : p.constant[i][j][k...]
@@ -444,11 +445,13 @@ function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val)
444445
# Nonnumeric parameters have to match the type
445446
if portion === NONNUMERIC_PORTION
446447
val isa stype && return nothing
447-
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
448+
throw(ParameterTypeException(
449+
:validate_parameter_type, sym === nothing ? index : sym, stype, val))
448450
end
449451
# Array parameters need array values...
450452
if stype <: AbstractArray && !isa(val, AbstractArray)
451-
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
453+
throw(ParameterTypeException(
454+
:validate_parameter_type, sym === nothing ? index : sym, stype, val))
452455
end
453456
# ... and must match sizes
454457
if stype <: AbstractArray && sz != Symbolics.Unknown() && size(val) != sz
@@ -465,15 +468,16 @@ function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val)
465468
# This is for duals and other complicated number types
466469
etype = SciMLBase.parameterless_type(etype)
467470
eltype(val) <: etype || throw(ParameterTypeException(
468-
:validate_parameter_type, sym, AbstractArray{etype}, val))
471+
:validate_parameter_type, sym === nothing ? index : sym, AbstractArray{etype}, val))
469472
else
470473
# Real check
471474
if stype <: Real
472475
stype = Real
473476
end
474477
stype = SciMLBase.parameterless_type(stype)
475478
val isa stype ||
476-
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
479+
throw(ParameterTypeException(
480+
:validate_parameter_type, sym === nothing ? index : sym, stype, val))
477481
end
478482
end
479483

0 commit comments

Comments
 (0)