diff --git a/src/problems/nonlinearproblem.jl b/src/problems/nonlinearproblem.jl index 9986ee665b..1ed7cda750 100644 --- a/src/problems/nonlinearproblem.jl +++ b/src/problems/nonlinearproblem.jl @@ -67,7 +67,8 @@ end check_length, check_compatibility, expression, kwargs...) kwargs = process_kwargs(sys; kwargs...) - args = (; f, u0, p, ptype = StandardNonlinearProblem()) + ptype = getmetadata(sys, ProblemTypeCtx, StandardNonlinearProblem()) + args = (; f, u0, p, ptype) return maybe_codegen_scimlproblem(expression, NonlinearProblem{iip}, args; kwargs...) end diff --git a/src/problems/odeproblem.jl b/src/problems/odeproblem.jl index e78ff00525..6726322907 100644 --- a/src/problems/odeproblem.jl +++ b/src/problems/odeproblem.jl @@ -77,7 +77,8 @@ end kwargs = process_kwargs( sys; expression, callback, eval_expression, eval_module, kwargs...) - args = (; f, u0, tspan, p, ptype = StandardODEProblem()) + ptype = getmetadata(sys, ProblemTypeCtx, StandardODEProblem()) + args = (; f, u0, tspan, p, ptype) maybe_codegen_scimlproblem(expression, ODEProblem{iip}, args; kwargs...) end diff --git a/src/systems/system.jl b/src/systems/system.jl index b30c046edb..cd8d711742 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -807,6 +807,16 @@ function SymbolicUtils.setmetadata(sys::AbstractSystem, k::DataType, v) @set sys.metadata = meta end +""" + $(TYPEDSIGNATURES) + +Metadata key for systems containing the `problem_type` to be passed to the problem +constructor, where applicable. For example, if `getmetadata(sys, ProblemTypeCtx, nothing)` +is `CustomType()` then `ODEProblem(sys, ...).problem_type` will be `CustomType()` instead +of `StandardODEProblem`. +""" +struct ProblemTypeCtx end + """ $(TYPEDSIGNATURES) """ diff --git a/test/nonlinearsystem.jl b/test/nonlinearsystem.jl index 4bacd6a50d..73cbbe9639 100644 --- a/test/nonlinearsystem.jl +++ b/test/nonlinearsystem.jl @@ -434,3 +434,11 @@ end @test resid == [0.0] @test resid isa SVector end + +@testset "`ProblemTypeCtx`" begin + @variables x + @mtkcompile sys = System( + [0 ~ x^2 - 4x + 4]; metadata = [ModelingToolkit.ProblemTypeCtx => "A"]) + prob = NonlinearProblem(sys, [x => 1.0]) + @test prob.problem_type == "A" +end diff --git a/test/odesystem.jl b/test/odesystem.jl index 073dbde5b2..c510b98429 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1566,3 +1566,11 @@ end @test !process_running(proc) kill(proc, Base.SIGKILL) end + +@testset "`ProblemTypeCtx`" begin + @variables x(t) + @mtkcompile sys = System( + [D(x) ~ x], t; metadata = [ModelingToolkit.ProblemTypeCtx => "A"]) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0)) + @test prob.problem_type == "A" +end