From b685ab5912dfee89682c7167f7c2c4e2afe2afb5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 8 Apr 2025 17:22:25 +0530 Subject: [PATCH 1/5] feat: add `@generated` method for `_remake_buffer` for type-inference --- src/systems/parameter_buffer.jl | 66 +++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 36105c0de8..3122ace245 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -496,6 +496,72 @@ end function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals) _remake_buffer(indp, oldbuf, idxs, vals) end + +# For type-inference when using `SII.setp_oop` +@generated function _remake_buffer( + indp, oldbuf::MTKParameters{T, I, D, C, N, H}, idxs::Tuple{Vararg{ParameterIndex}}, + vals::Union{AbstractArray, Tuple}; validate = true) where {T, I, D, C, N, H} + valtype(i) = vals <: AbstractArray ? eltype(vals) : fieldtype(vals, i) + tunablesT = eltype(T) + for (i, idxT) in enumerate(fieldtypes(idxs)) + idxT <: ParameterIndex{SciMLStructures.Tunable} || continue + tunablesT = promote_type(tunablesT, valtype(i)) + end + initialsT = eltype(I) + for (i, idxT) in enumerate(fieldtypes(idxs)) + idxT <: ParameterIndex{SciMLStructures.Initials} || continue + initialsT = promote_type(initialsT, valtype(i)) + end + discretesT = ntuple(Val(fieldcount(D))) do i + bufT = eltype(fieldtype(D, i)) + for (j, idxT) in enumerate(fieldtypes(idxs)) + idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue + bufT = promote_type(bufT, valtype(i)) + end + bufT + end + constantsT = ntuple(Val(fieldcount(C))) do i + bufT = eltype(fieldtype(C, i)) + for (j, idxT) in enumerate(fieldtypes(idxs)) + idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue + bufT = promote_type(bufT, valtype(i)) + end + bufT + end + nonnumericT = ntuple(Val(fieldcount(N))) do i + bufT = eltype(fieldtype(N, i)) + for (j, idxT) in enumerate(fieldtypes(idxs)) + idxT <: ParameterIndex{Nonnumeric, i} || continue + bufT = promote_type(bufT, valtype(i)) + end + bufT + end + + expr = quote + tunables = $similar(oldbuf.tunable, $tunablesT) + copyto!(tunables, oldbuf.tunable) + initials = $similar(oldbuf.initials, $initialsT) + copyto!(initials, oldbuf.initials) + discretes = $(Expr(:tuple, + (:($similar(oldbuf.discrete[$i], $(discretesT[i]))) for i in 1:length(discretesT))...)) + $((:($copyto!(discretes[$i], oldbuf.discrete[$i])) for i in 1:length(discretesT))...) + constants = $(Expr(:tuple, + (:($similar(oldbuf.constant[$i], $(constantsT[i]))) for i in 1:length(constantsT))...)) + $((:($copyto!(constants[$i], oldbuf.constant[$i])) for i in 1:length(constantsT))...) + nonnumerics = $(Expr(:tuple, + (:($similar(oldbuf.nonnumeric[$i], $(nonnumericT[i]))) for i in 1:length(nonnumericT))...)) + $((:($copyto!(nonnumerics[$i], oldbuf.nonnumeric[$i])) for i in 1:length(nonnumericT))...) + newbuf = MTKParameters( + tunables, initials, discretes, constants, nonnumerics, copy.(oldbuf.caches)) + end + for i in 1:fieldcount(idxs) + push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i]))) + end + push!(expr.args, :(return newbuf)) + + return expr +end + function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true) newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any) @set! newbuf.initials = similar(oldbuf.initials, Any) From 0ba61201650b7e3a2d3adbb388e7a546a6bb8e63 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 13:04:51 +0530 Subject: [PATCH 2/5] build: bump SII compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 63aef5b8da..693a81e148 100644 --- a/Project.toml +++ b/Project.toml @@ -149,7 +149,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" StochasticDelayDiffEq = "1.8.1" StochasticDiffEq = "6.72.1" -SymbolicIndexingInterface = "0.3.37" +SymbolicIndexingInterface = "0.3.39" SymbolicUtils = "3.25.1" Symbolics = "6.37" URIs = "1" From 96694de0e47208da1e2d1e0f557a9c7f271e6fa2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Apr 2025 13:06:34 +0530 Subject: [PATCH 3/5] feat: handle more type-stable cases in `@generated _remake_buffer` --- src/systems/parameter_buffer.jl | 72 ++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 15 deletions(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 3122ace245..0916bb1185 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -499,40 +499,76 @@ end # For type-inference when using `SII.setp_oop` @generated function _remake_buffer( - indp, oldbuf::MTKParameters{T, I, D, C, N, H}, idxs::Tuple{Vararg{ParameterIndex}}, - vals::Union{AbstractArray, Tuple}; validate = true) where {T, I, D, C, N, H} - valtype(i) = vals <: AbstractArray ? eltype(vals) : fieldtype(vals, i) + indp, oldbuf::MTKParameters{T, I, D, C, N, H}, + idxs::Union{Tuple{Vararg{ParameterIndex}}, AbstractArray{<:ParameterIndex{P}}}, + vals::Union{AbstractArray, Tuple}; validate = true) where {T, I, D, C, N, H, P} + + # fallback to non-generated method if values aren't type-stable + if vals <: AbstractArray && !isconcretetype(eltype(vals)) + return quote + $_remake_buffer(indp, oldbuf, collect(idxs), vals; validate) + end + end + + # given an index in idxs/vals and the current `eltype` of the buffer, + # return the promoted eltype of the buffer + function promote_valtype(i, valT) + # tuples have distinct types, arrays have a common eltype + valT′ = vals <: AbstractArray ? eltype(vals) : fieldtype(vals, i) + # if the buffer is a scalarized buffer but the variable is an array + # e.g. an array tunable, take the eltype + if valT′ <: AbstractArray && !(valT <: AbstractArray) + valT′ = eltype(valT′) + end + return promote_type(valT, valT′) + end + + # types of the idxs + idxtypes = if idxs <: AbstractArray + # if both are arrays, there is only one possible type to check + if vals <: AbstractArray + (eltype(idxs),) + else + # if `vals` is a tuple, we repeat `eltype(idxs)` to check against + # every possible type of the buffer + ntuple(Returns(eltype(idxs)), Val(fieldcount(vals))) + end + else + # `idxs` is a tuple, so we check against all buffers + fieldtypes(idxs) + end + # promote types tunablesT = eltype(T) - for (i, idxT) in enumerate(fieldtypes(idxs)) + for (i, idxT) in enumerate(idxtypes) idxT <: ParameterIndex{SciMLStructures.Tunable} || continue - tunablesT = promote_type(tunablesT, valtype(i)) + tunablesT = promote_valtype(i, tunablesT) end initialsT = eltype(I) - for (i, idxT) in enumerate(fieldtypes(idxs)) + for (i, idxT) in enumerate(idxtypes) idxT <: ParameterIndex{SciMLStructures.Initials} || continue - initialsT = promote_type(initialsT, valtype(i)) + initialsT = promote_valtype(i, initialsT) end discretesT = ntuple(Val(fieldcount(D))) do i bufT = eltype(fieldtype(D, i)) - for (j, idxT) in enumerate(fieldtypes(idxs)) + for (j, idxT) in enumerate(idxtypes) idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue - bufT = promote_type(bufT, valtype(i)) + bufT = promote_valtype(i, bufT) end bufT end constantsT = ntuple(Val(fieldcount(C))) do i bufT = eltype(fieldtype(C, i)) - for (j, idxT) in enumerate(fieldtypes(idxs)) + for (j, idxT) in enumerate(idxtypes) idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue - bufT = promote_type(bufT, valtype(i)) + bufT = promote_valtype(i, bufT) end bufT end nonnumericT = ntuple(Val(fieldcount(N))) do i bufT = eltype(fieldtype(N, i)) - for (j, idxT) in enumerate(fieldtypes(idxs)) + for (j, idxT) in enumerate(idxtypes) idxT <: ParameterIndex{Nonnumeric, i} || continue - bufT = promote_type(bufT, valtype(i)) + bufT = promote_valtype(i, bufT) end bufT end @@ -554,8 +590,14 @@ end newbuf = MTKParameters( tunables, initials, discretes, constants, nonnumerics, copy.(oldbuf.caches)) end - for i in 1:fieldcount(idxs) - push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i]))) + if idxs <: AbstractArray + push!(expr.args, :(for (idx, val) in zip(idxs, vals) + $setindex!(newbuf, val, idx) + end)) + else + for i in 1:fieldcount(idxs) + push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i]))) + end end push!(expr.args, :(return newbuf)) From c5e8240458b503f2cfd8e92ad9c0c4b2e73456fb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Apr 2025 13:06:48 +0530 Subject: [PATCH 4/5] test: test type-stability of `remake_buffer` in supported cases --- test/mtkparameters.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index f9d71f00bc..22201b1988 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -91,6 +91,27 @@ end @test getp(sys, f)(newps) isa Matrix{UInt} @test getp(sys, g)(newps) isa Vector{Float32} +@testset "Type-stability of `remake_buffer`" begin + prob = ODEProblem(sys, [], (0.0, 1.0), ivs) + + idxs = (a, c, d, e, f, g, h) + vals = (1.0, 2.0, 3, ones(3), ones(Int, 3, 3), ones(2), "a") + + setter = setsym_oop(prob, idxs) + @test_nowarn @inferred setter(prob, vals) + @test_throws ErrorException @inferred setter(prob, collect(vals)) + + idxs = (a, c, e...) + vals = Float16[1.0, 2.0, 3.0, 4.0, 5.0] + setter = setsym_oop(prob, idxs) + @test_nowarn @inferred setter(prob, vals) + + idxs = [a, e] + vals = (Float16(1.0), ForwardDiff.Dual{Nothing, Float16, 0}[1.0, 2.0, 3.0]) + setter = setsym_oop(prob, idxs) + @test_nowarn @inferred setter(prob, vals) +end + ps = MTKParameters(sys, ivs) function loss(value, sys, ps) @test value isa ForwardDiff.Dual From 06a63df059223f6625e0a5b221102563f53825b4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Apr 2025 13:39:54 +0530 Subject: [PATCH 5/5] fix: fix stack overflow in `@generated _remake_buffer` fallback --- src/systems/parameter_buffer.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 0916bb1185..672fb58795 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -506,7 +506,7 @@ end # fallback to non-generated method if values aren't type-stable if vals <: AbstractArray && !isconcretetype(eltype(vals)) return quote - $_remake_buffer(indp, oldbuf, collect(idxs), vals; validate) + $__remake_buffer(indp, oldbuf, collect(idxs), vals; validate) end end @@ -605,6 +605,10 @@ end end function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true) + return __remake_buffer(indp, oldbuf, idxs, vals; validate) +end + +function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true) newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any) @set! newbuf.initials = similar(oldbuf.initials, Any) @set! newbuf.discrete = Tuple(similar(buf, Any) for buf in newbuf.discrete)