Skip to content

Commit 8648610

Browse files
Merge pull request #3674 from AayushSabharwal/as/v10-promotion
fix: fix type promotion in `late_binding_update_u0_p` with non-dual types
2 parents b1e09a6 + 3d83e08 commit 8648610

File tree

2 files changed

+43
-13
lines changed

2 files changed

+43
-13
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -638,24 +638,44 @@ function SciMLBase.remake_initialization_data(
638638
return SciMLBase.remake_initialization_data(sys, odefn, newu0, t0, newp, newu0, newp)
639639
end
640640

641-
function promote_u0_p(u0, p::MTKParameters, t0)
642-
u0 = DiffEqBase.promote_u0(u0, p.tunable, t0)
643-
u0 = DiffEqBase.promote_u0(u0, p.initials, t0)
641+
promote_type_with_nothing(::Type{T}, ::Nothing) where {T} = T
642+
promote_type_with_nothing(::Type{T}, ::SizedVector{0}) where {T} = T
643+
function promote_type_with_nothing(::Type{T}, ::AbstractArray{T2}) where {T, T2}
644+
promote_type(T, T2)
645+
end
646+
function promote_type_with_nothing(::Type{T}, p::MTKParameters) where {T}
647+
promote_type_with_nothing(promote_type_with_nothing(T, p.tunable), p.initials)
648+
end
644649

645-
if !isempty(p.tunable)
646-
tunables = DiffEqBase.promote_u0(p.tunable, u0, t0)
647-
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
648-
end
649-
if !isempty(p.initials)
650-
initials = DiffEqBase.promote_u0(p.initials, u0, t0)
651-
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
650+
promote_with_nothing(::Type, ::Nothing) = nothing
651+
promote_with_nothing(::Type, x::SizedVector{0}) = x
652+
promote_with_nothing(::Type{T}, x::AbstractArray{T}) where {T} = x
653+
function promote_with_nothing(::Type{T}, x::AbstractArray{T2}) where {T, T2}
654+
if ArrayInterface.ismutable(x)
655+
y = similar(x, T)
656+
copyto!(y, x)
657+
return y
658+
else
659+
yT = similar_type(x, T)
660+
return yT(x)
652661
end
653-
654-
return u0, p
662+
end
663+
function promote_with_nothing(::Type{T}, p::MTKParameters) where {T}
664+
tunables = promote_with_nothing(T, p.tunable)
665+
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
666+
initials = promote_with_nothing(T, p.initials)
667+
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
668+
return p
655669
end
656670

657671
function promote_u0_p(u0, p, t0)
658-
return DiffEqBase.promote_u0(u0, p, t0), DiffEqBase.promote_u0(p, u0, t0)
672+
T = Union{}
673+
T = promote_type_with_nothing(T, u0)
674+
T = promote_type_with_nothing(T, p)
675+
676+
u0 = promote_with_nothing(T, u0)
677+
p = promote_with_nothing(T, p)
678+
return u0, p
659679
end
660680

661681
function SciMLBase.late_binding_update_u0_p(

test/initial_values.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,13 @@ end
346346
@test state_values(initdata.initializeprob) isa SVector
347347
@test parameter_values(initdata.initializeprob) isa SVector
348348
end
349+
350+
@testset "Type promotion of `p` works with non-dual types" begin
351+
@variables x(t) y(t)
352+
@mtkcompile sys = System([D(x) ~ x + y, x^3 + y^3 ~ 5], t; guesses = [y => 1.0])
353+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
354+
prob2 = remake(prob; u0 = BigFloat.(prob.u0))
355+
@test prob2.p.initials isa Vector{BigFloat}
356+
sol = solve(prob2)
357+
@test SciMLBase.successful_retcode(sol)
358+
end

0 commit comments

Comments
 (0)