Skip to content

Commit 9d0b7ef

Browse files
Merge pull request #3695 from fchen121/master
DE Transformation (Change of Variables)
2 parents 34a4afd + 96609b3 commit 9d0b7ef

File tree

4 files changed

+284
-1
lines changed

4 files changed

+284
-1
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
297297
hasunit, getunit, hasconnect, getconnect,
298298
hasmisc, getmisc, state_priority
299299
export liouville_transform, change_independent_variable, substitute_component,
300-
add_accumulations, noise_to_brownians, Girsanov_transform
300+
add_accumulations, noise_to_brownians, Girsanov_transform, changeofvariables
301301
export PDESystem
302302
export Differential, expand_derivatives, @derivatives
303303
export Equation, ConstrainedEquation

src/systems/diffeqs/basic_transformations.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,138 @@ function liouville_transform(sys::System; kwargs...)
5353
)
5454
end
5555

56+
"""
57+
$(TYPEDSIGNATURES)
58+
59+
Generates the set of ODEs after change of variables.
60+
61+
62+
Example:
63+
64+
```julia
65+
using ModelingToolkit, OrdinaryDiffEq, Test
66+
67+
# Change of variables: z = log(x)
68+
# (this implies that x = exp(z) is automatically non-negative)
69+
70+
@parameters t α
71+
@variables x(t)
72+
D = Differential(t)
73+
eqs = [D(x) ~ α*x]
74+
75+
tspan = (0., 1.)
76+
u0 = [x => 1.0]
77+
p = [α => -0.5]
78+
79+
@named sys = ODESystem(eqs; defaults=u0)
80+
prob = ODEProblem(sys, [], tspan, p)
81+
sol = solve(prob, Tsit5())
82+
83+
@variables z(t)
84+
forward_subs = [log(x) => z]
85+
backward_subs = [x => exp(z)]
86+
87+
@named new_sys = changeofvariables(sys, forward_subs, backward_subs)
88+
@test equations(new_sys)[1] == (D(z) ~ α)
89+
90+
new_prob = ODEProblem(new_sys, [], tspan, p)
91+
new_sol = solve(new_prob, Tsit5())
92+
93+
@test isapprox(new_sol[x][end], sol[x][end], atol=1e-4)
94+
```
95+
96+
"""
97+
function changeofvariables(
98+
sys::System, iv, forward_subs, backward_subs;
99+
simplify=true, t0=missing, isSDE=false
100+
)
101+
t = iv
102+
103+
old_vars = first.(backward_subs)
104+
new_vars = last.(forward_subs)
105+
106+
# use: f = Y(t, X)
107+
# use: dY = (∂f/∂t + μ∂f/∂x + (1/2)*σ^2*∂2f/∂x2)dt + σ∂f/∂xdW
108+
old_eqs = equations(sys)
109+
neqs = get_noise_eqs(sys)
110+
brownvars = brownians(sys)
111+
112+
113+
if neqs === nothing && length(brownvars) === 0
114+
neqs = ones(1, length(old_eqs))
115+
elseif neqs !== nothing
116+
isSDE = true
117+
neqs = [neqs[i,:] for i in 1:size(neqs,1)]
118+
119+
brownvars = map([Symbol(:B, :_, i) for i in 1:length(neqs[1])]) do name
120+
unwrap(only(@brownians $name))
121+
end
122+
else
123+
isSDE = true
124+
neqs = Vector{Any}[]
125+
for (i, eq) in enumerate(old_eqs)
126+
neq = Any[]
127+
right = eq.rhs
128+
for Bv in brownvars
129+
lin_exp = linear_expansion(right, Bv)
130+
right = lin_exp[2]
131+
push!(neq, lin_exp[1])
132+
end
133+
push!(neqs, neq)
134+
old_eqs[i] = eq.lhs ~ right
135+
end
136+
end
137+
138+
# df/dt = ∂f/∂x dx/dt + ∂f/∂t
139+
dfdt = Symbolics.derivative( first.(forward_subs), t )
140+
∂f∂x = [Symbolics.derivative( first(f_sub), old_var ) for (f_sub, old_var) in zip(forward_subs, old_vars)]
141+
∂2f∂x2 = Symbolics.derivative.( ∂f∂x, old_vars )
142+
new_eqs = Equation[]
143+
144+
for (new_var, ex, first, second) in zip(new_vars, dfdt, ∂f∂x, ∂2f∂x2)
145+
for (eqs, neq) in zip(old_eqs, neqs)
146+
if occursin(value(eqs.lhs), value(ex))
147+
ex = substitute(ex, eqs.lhs => eqs.rhs)
148+
if isSDE
149+
for (noise, B) in zip(neq, brownvars)
150+
ex = ex + 1/2 * noise^2 * second + noise*first*B
151+
end
152+
end
153+
end
154+
end
155+
ex = substitute(ex, Dict(forward_subs))
156+
ex = substitute(ex, Dict(backward_subs))
157+
if simplify
158+
ex = Symbolics.simplify(ex, expand=true)
159+
end
160+
push!(new_eqs, Differential(t)(new_var) ~ ex)
161+
end
162+
163+
defs = get_defaults(sys)
164+
new_defs = Dict()
165+
for f_sub in forward_subs
166+
ex = substitute(first(f_sub), defs)
167+
if !ismissing(t0)
168+
ex = substitute(ex, t => t0)
169+
end
170+
new_defs[last(f_sub)] = ex
171+
end
172+
for para in parameters(sys)
173+
if haskey(defs, para)
174+
new_defs[para] = defs[para]
175+
end
176+
end
177+
178+
@named new_sys = System(vcat(new_eqs, first.(backward_subs) .~ last.(backward_subs)), t;
179+
defaults=new_defs,
180+
observed=observed(sys)
181+
)
182+
if simplify
183+
return mtkcompile(new_sys)
184+
end
185+
return new_sys
186+
end
187+
56188
"""
57189
change_independent_variable(
58190
sys::System, iv, eqs = [];

test/changeofvariables.jl

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
using ModelingToolkit, OrdinaryDiffEq, StochasticDiffEq
2+
using Test, LinearAlgebra
3+
4+
5+
# Change of variables: z = log(x)
6+
# (this implies that x = exp(z) is automatically non-negative)
7+
@independent_variables t
8+
@variables z(t)[1:2, 1:2]
9+
D = Differential(t)
10+
eqs = [D(D(z)) ~ ones(2, 2)]
11+
@mtkcompile sys = System(eqs, t)
12+
@test_nowarn ODEProblem(sys, [z => zeros(2, 2), D(z) => ones(2, 2)], (0.0, 10.0))
13+
14+
@parameters α
15+
@variables x(t)
16+
D = Differential(t)
17+
eqs = [D(x) ~ α*x]
18+
19+
tspan = (0., 1.)
20+
def = [x => 1.0, α => -0.5]
21+
22+
@mtkcompile sys = System(eqs, t;defaults=def)
23+
prob = ODEProblem(sys, [], tspan)
24+
sol = solve(prob, Tsit5())
25+
26+
@variables z(t)
27+
forward_subs = [log(x) => z]
28+
backward_subs = [x => exp(z)]
29+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs)
30+
@test equations(new_sys)[1] == (D(z) ~ α)
31+
32+
new_prob = ODEProblem(new_sys, [], tspan)
33+
new_sol = solve(new_prob, Tsit5())
34+
35+
@test isapprox(new_sol[x][end], sol[x][end], atol=1e-4)
36+
37+
38+
39+
# Riccati equation
40+
@parameters α
41+
@variables x(t)
42+
D = Differential(t)
43+
eqs = [D(x) ~ t^2 + α - x^2]
44+
def = [x=>1., α => 1.]
45+
@mtkcompile sys = System(eqs, t; defaults=def)
46+
47+
@variables z(t)
48+
forward_subs = [t + α/(x+t) => z ]
49+
backward_subs = [ x => α/(z-t) - t]
50+
51+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs; simplify=true, t0=0.)
52+
# output should be equivalent to
53+
# t^2 + α - z^2 + 2 (but this simplification is not found automatically)
54+
55+
tspan = (0., 1.)
56+
prob = ODEProblem(sys,[],tspan)
57+
new_prob = ODEProblem(new_sys,[],tspan)
58+
59+
sol = solve(prob, Tsit5())
60+
new_sol = solve(new_prob, Tsit5())
61+
62+
@test isapprox(sol[x][end], new_sol[x][end], rtol=1e-4)
63+
64+
65+
# Linear transformation to diagonal system
66+
@independent_variables t
67+
@variables x(t)[1:3]
68+
x = reshape(x, 3, 1)
69+
D = Differential(t)
70+
A = [0. -1. 0.; -0.5 0.5 0.; 0. 0. -1.]
71+
right = A*x
72+
eqs = vec(D.(x) .~ right)
73+
74+
tspan = (0., 10.)
75+
u0 = [x[1] => 1.0, x[2] => 2.0, x[3] => -1.0]
76+
77+
@mtkcompile sys = System(eqs, t; defaults=u0)
78+
prob = ODEProblem(sys,[],tspan)
79+
sol = solve(prob, Tsit5())
80+
81+
T = eigen(A).vectors
82+
T_inv = inv(T)
83+
84+
@variables z(t)[1:3]
85+
z = reshape(z, 3, 1)
86+
forward_subs = vec(T_inv*x .=> z)
87+
backward_subs = vec(x .=> T*z)
88+
89+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs; simplify=true)
90+
91+
new_prob = ODEProblem(new_sys, [], tspan)
92+
new_sol = solve(new_prob, Tsit5())
93+
94+
# test RHS
95+
new_rhs = [eq.rhs for eq in equations(new_sys)]
96+
new_A = Symbolics.value.(Symbolics.jacobian(new_rhs, z))
97+
A = diagm(eigen(A).values)
98+
A = sortslices(A, dims=1)
99+
new_A = sortslices(new_A, dims=1)
100+
@test isapprox(A, new_A, rtol = 1e-10)
101+
@test isapprox( new_sol[x[1],end], sol[x[1],end], rtol=1e-4)
102+
103+
# Change of variables for sde
104+
noise_eqs = ModelingToolkit.get_noise_eqs
105+
value = ModelingToolkit.value
106+
107+
@independent_variables t
108+
@brownians B
109+
@parameters μ σ
110+
@variables x(t) y(t)
111+
D = Differential(t)
112+
eqs = [D(x) ~ μ*x + σ*x*B]
113+
114+
def = [x=>0., μ => 2., σ=>1.]
115+
@mtkcompile sys = System(eqs, t; defaults=def)
116+
forward_subs = [log(x) => y]
117+
backward_subs = [x => exp(y)]
118+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs)
119+
@test equations(new_sys)[1] == (D(y) ~ μ - 1/2*σ^2)
120+
@test noise_eqs(new_sys)[1] === value(σ)
121+
122+
#Multiple Brownian and equations
123+
@independent_variables t
124+
@brownians Bx By
125+
@parameters μ σ α
126+
@variables x(t) y(t) z(t) w(t) u(t) v(t)
127+
D = Differential(t)
128+
eqs = [D(x) ~ μ*x + σ*x*Bx, D(y) ~ α*By, D(u) ~ μ*u + σ*u*Bx + α*u*By]
129+
def = [x=>0., y=> 0., u=>0., μ => 2., σ=>1., α=>3.]
130+
forward_subs = [log(x) => z, y^2 => w, log(u) => v]
131+
backward_subs = [x => exp(z), y => w^.5, u => exp(v)]
132+
133+
@mtkcompile sys = System(eqs, t; defaults=def)
134+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs)
135+
@test equations(new_sys)[1] == (D(z) ~ μ - 1/2*σ^2)
136+
@test equations(new_sys)[2] == (D(w) ~ α^2)
137+
@test equations(new_sys)[3] == (D(v) ~ μ - 1/2*^2 + σ^2))
138+
@test noise_eqs(new_sys)[1,1] === value(σ)
139+
@test noise_eqs(new_sys)[1,2] === value(0)
140+
@test noise_eqs(new_sys)[2,1] === value(0)
141+
@test noise_eqs(new_sys)[2,2] === value(substitute(2*α*y, backward_subs[2]))
142+
@test noise_eqs(new_sys)[3,1] === value(σ)
143+
@test noise_eqs(new_sys)[3,2] === value(α)
144+
145+
# Test for Brownian instead of noise
146+
@named sys = System(eqs, t; defaults=def)
147+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs; simplify=false)
148+
@test simplify(equations(new_sys)[1]) == simplify((D(z) ~ μ - 1/2*σ^2 + σ*Bx))
149+
@test simplify(equations(new_sys)[2]) == simplify((D(w) ~ α^2 + 2*α*w^.5*By))
150+
@test simplify(equations(new_sys)[3]) == simplify((D(v) ~ μ - 1/2*^2 + σ^2) + σ*Bx + α*By))

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ end
4747
@safetestset "Error Handling" include("error_handling.jl")
4848
@safetestset "StructuralTransformations" include("structural_transformation/runtests.jl")
4949
@safetestset "Basic transformations" include("basic_transformations.jl")
50+
@safetestset "Change of variables" include("changeofvariables.jl")
5051
@safetestset "State Selection Test" include("state_selection.jl")
5152
@safetestset "Symbolic Event Test" include("symbolic_events.jl")
5253
@safetestset "Stream Connect Test" include("stream_connectors.jl")

0 commit comments

Comments
 (0)