Skip to content

Commit 12ef813

Browse files
Merge pull request #3685 from AayushSabharwal/as/prob-type
feat: allow specifying `problem_type` using system metadata
2 parents 8648610 + 9c47eaa commit 12ef813

File tree

5 files changed

+30
-2
lines changed

5 files changed

+30
-2
lines changed

src/problems/nonlinearproblem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ end
6767
check_length, check_compatibility, expression, kwargs...)
6868

6969
kwargs = process_kwargs(sys; kwargs...)
70-
args = (; f, u0, p, ptype = StandardNonlinearProblem())
70+
ptype = getmetadata(sys, ProblemTypeCtx, StandardNonlinearProblem())
71+
args = (; f, u0, p, ptype)
7172

7273
return maybe_codegen_scimlproblem(expression, NonlinearProblem{iip}, args; kwargs...)
7374
end

src/problems/odeproblem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ end
7777
kwargs = process_kwargs(
7878
sys; expression, callback, eval_expression, eval_module, kwargs...)
7979

80-
args = (; f, u0, tspan, p, ptype = StandardODEProblem())
80+
ptype = getmetadata(sys, ProblemTypeCtx, StandardODEProblem())
81+
args = (; f, u0, tspan, p, ptype)
8182
maybe_codegen_scimlproblem(expression, ODEProblem{iip}, args; kwargs...)
8283
end
8384

src/systems/system.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,16 @@ function SymbolicUtils.setmetadata(sys::AbstractSystem, k::DataType, v)
804804
@set sys.metadata = meta
805805
end
806806

807+
"""
808+
$(TYPEDSIGNATURES)
809+
810+
Metadata key for systems containing the `problem_type` to be passed to the problem
811+
constructor, where applicable. For example, if `getmetadata(sys, ProblemTypeCtx, nothing)`
812+
is `CustomType()` then `ODEProblem(sys, ...).problem_type` will be `CustomType()` instead
813+
of `StandardODEProblem`.
814+
"""
815+
struct ProblemTypeCtx end
816+
807817
"""
808818
$(TYPEDSIGNATURES)
809819
"""

test/nonlinearsystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,3 +434,11 @@ end
434434
@test resid == [0.0]
435435
@test resid isa SVector
436436
end
437+
438+
@testset "`ProblemTypeCtx`" begin
439+
@variables x
440+
@mtkcompile sys = System(
441+
[0 ~ x^2 - 4x + 4]; metadata = [ModelingToolkit.ProblemTypeCtx => "A"])
442+
prob = NonlinearProblem(sys, [x => 1.0])
443+
@test prob.problem_type == "A"
444+
end

test/odesystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,3 +1567,11 @@ end
15671567
@test !process_running(proc)
15681568
kill(proc, Base.SIGKILL)
15691569
end
1570+
1571+
@testset "`ProblemTypeCtx`" begin
1572+
@variables x(t)
1573+
@mtkcompile sys = System(
1574+
[D(x) ~ x], t; metadata = [ModelingToolkit.ProblemTypeCtx => "A"])
1575+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
1576+
@test prob.problem_type == "A"
1577+
end

0 commit comments

Comments
 (0)