Skip to content

Commit a35e1e4

Browse files
Merge pull request #3429 from SciML/fb/linextras
output extra information from linearization
2 parents d7854ad + 77dad69 commit a35e1e4

File tree

4 files changed

+29
-20
lines changed

4 files changed

+29
-20
lines changed

src/linearization.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ end
275275
"""
276276
$(TYPEDSIGNATURES)
277277
278-
Linearize the wrapped system at the point given by `(u, p, t)`.
278+
Linearize the wrapped system at the point given by `(unknowns, p, t)`.
279279
"""
280280
function (linfun::LinearizationFunction)(u, p, t)
281281
if eltype(p) <: Pair
@@ -301,7 +301,7 @@ function (linfun::LinearizationFunction)(u, p, t)
301301
linfun.prob, integ, fun, linfun.initializealg, Val(true);
302302
linfun.initialize_kwargs...)
303303
if !success
304-
error("Initialization algorithm $(linfun.initializealg) failed with `u = $u` and `p = $p`.")
304+
error("Initialization algorithm $(linfun.initializealg) failed with `unknowns = $u` and `p = $p`.")
305305
end
306306
fg_xz = linfun.uf_jac(u, DI.Constant(p), DI.Constant(t))
307307
h_xz = linfun.h_jac(u, DI.Constant(p), DI.Constant(t))
@@ -323,7 +323,10 @@ function (linfun::LinearizationFunction)(u, p, t)
323323
g_u = fg_u[linfun.alge_idxs, :],
324324
h_x = h_xz[:, linfun.diff_idxs],
325325
h_z = h_xz[:, linfun.alge_idxs],
326-
h_u = h_u)
326+
h_u = h_u,
327+
x = u,
328+
p,
329+
t)
327330
end
328331

329332
"""
@@ -436,7 +439,7 @@ function CommonSolve.solve(prob::LinearizationProblem; allow_input_derivatives =
436439
p = parameter_values(prob)
437440
t = current_time(prob)
438441
linres = prob.f(u0, p, t)
439-
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres
442+
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u, x, p, t = linres
440443

441444
nx, nu = size(f_u)
442445
nz = size(f_z, 2)
@@ -473,7 +476,7 @@ function CommonSolve.solve(prob::LinearizationProblem; allow_input_derivatives =
473476
end
474477
end
475478

476-
(; A, B, C, D)
479+
(; A, B, C, D), (; x, p, t)
477480
end
478481

479482
"""
@@ -618,8 +621,8 @@ function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true
618621
end
619622

620623
"""
621-
(; A, B, C, D), simplified_sys = linearize(sys, inputs, outputs; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false, kwargs...)
622-
(; A, B, C, D) = linearize(simplified_sys, lin_fun; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false)
624+
(; A, B, C, D), simplified_sys, extras = linearize(sys, inputs, outputs; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false, kwargs...)
625+
(; A, B, C, D), extras = linearize(simplified_sys, lin_fun; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false)
623626
624627
Linearize `sys` between `inputs` and `outputs`, both vectors of variables. Return a NamedTuple with the matrices of a linear statespace representation
625628
on the form
@@ -641,6 +644,8 @@ If `allow_input_derivatives = false`, an error will be thrown if input derivativ
641644
642645
`zero_dummy_der` can be set to automatically set the operating point to zero for all dummy derivatives.
643646
647+
The return value `extras` is a NamedTuple `(; x, p, t)` containing the result of the initialization problem that was solved to determine the operating point.
648+
644649
See also [`linearization_function`](@ref) which provides a lower-level interface, [`linearize_symbolic`](@ref) and [`ModelingToolkit.reorder_unknowns`](@ref).
645650
646651
See extended help for an example.
@@ -750,7 +755,8 @@ function linearize(sys, inputs, outputs; op = Dict(), t = 0.0,
750755
zero_dummy_der,
751756
op,
752757
kwargs...)
753-
linearize(ssys, lin_fun; op, t, allow_input_derivatives), ssys
758+
mats, extras = linearize(ssys, lin_fun; op, t, allow_input_derivatives)
759+
mats, ssys, extras
754760
end
755761

756762
"""

src/systems/analysis_points.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,8 @@ for f in [:get_sensitivity, :get_comp_sensitivity, :get_looptransfer]
908908
sys, ap, args...; loop_openings = [], system_modifier = identity, kwargs...)
909909
lin_fun, ssys = $(utility_fun)(
910910
sys, ap, args...; loop_openings, system_modifier, kwargs...)
911-
ModelingToolkit.linearize(ssys, lin_fun), ssys
911+
mats, extras = ModelingToolkit.linearize(ssys, lin_fun)
912+
mats, ssys, extras
912913
end
913914
end
914915

test/downstream/inversemodel.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,28 +132,28 @@ sol = solve(prob, Rodas5P())
132132

133133
# we need to provide `op` so the initialization system knows what to hold constant
134134
# the values don't matter
135-
Sf, simplified_sys = Blocks.get_sensitivity_function(model, :y; op); # This should work without providing an operating opint containing a dummy derivative
135+
Sf, simplified_sys = get_sensitivity_function(model, :y; op); # This should work without providing an operating opint containing a dummy derivative
136136
x = state_values(Sf)
137137
p = parameter_values(Sf)
138138
# If this somehow passes, mention it on
139139
# https://github.com/SciML/ModelingToolkit.jl/issues/2786
140140
matrices1 = Sf(x, p, 0)
141-
matrices2, _ = Blocks.get_sensitivity(model, :y; op); # Test that we get the same result when calling the higher-level API
141+
matrices2, _ = get_sensitivity(model, :y; op); # Test that we get the same result when calling the higher-level API
142142
@test matrices1.f_x matrices2.A[1:6, 1:6]
143143
nsys = get_named_sensitivity(model, :y; op) # Test that we get the same result when calling an even higher-level API
144144
@test matrices2.A nsys.A
145145

146146
# Test the same thing for comp sensitivities
147147

148148
# This should work without providing an operating opint containing a dummy derivative
149-
Sf, simplified_sys = Blocks.get_comp_sensitivity_function(model, :y; op);
149+
Sf, simplified_sys = get_comp_sensitivity_function(model, :y; op);
150150
x = state_values(Sf)
151151
p = parameter_values(Sf)
152152
# If this somehow passes, mention it on
153153
# https://github.com/SciML/ModelingToolkit.jl/issues/2786
154154
matrices1 = Sf(x, p, 0)
155155
# Test that we get the same result when calling the higher-level API
156-
matrices2, _ = Blocks.get_comp_sensitivity(model, :y; op)
156+
matrices2, _ = get_comp_sensitivity(model, :y; op)
157157
@test matrices1.f_x matrices2.A[1:6, 1:6]
158158
# Test that we get the same result when calling an even higher-level API
159159
nsys = get_named_comp_sensitivity(model, :y; op)
@@ -173,15 +173,16 @@ nsys = get_named_comp_sensitivity(model, :y; op)
173173
output = :y
174174
# we need to provide `op` so the initialization system knows which
175175
# values to hold constant
176-
lin_fun, ssys = Blocks.get_sensitivity_function(model, output; op = op1)
177-
matrices1 = linearize(ssys, lin_fun, op = op1)
178-
matrices2 = linearize(ssys, lin_fun, op = op2)
176+
lin_fun, ssys = get_sensitivity_function(model, output; op = op1)
177+
matrices1, extras1 = linearize(ssys, lin_fun, op = op1)
178+
matrices2, extras2 = linearize(ssys, lin_fun, op = op2)
179+
@test extras1.x != extras2.x
179180
S1f = ss(matrices1...)
180181
S2f = ss(matrices2...)
181182
@test S1f != S2f
182183

183-
matrices1, ssys = Blocks.get_sensitivity(model, output; op = op1)
184-
matrices2, ssys = Blocks.get_sensitivity(model, output; op = op2)
184+
matrices1, ssys = get_sensitivity(model, output; op = op1)
185+
matrices2, ssys = get_sensitivity(model, output; op = op2)
185186
S1 = ss(matrices1...)
186187
S2 = ss(matrices2...)
187188
@test S1 != S2

test/linearize.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@ eqs = [u ~ kp * (r - y)
1414

1515
@named sys = System(eqs, t)
1616

17-
lsys, ssys = linearize(sys, [r], [y])
17+
lsys, ssys, extras = linearize(sys, [r], [y])
1818
lprob = LinearizationProblem(sys, [r], [y])
19-
lsys2 = solve(lprob)
19+
lsys2, extras2 = solve(lprob)
2020
lsys3, _ = linearize(sys, [r], [y]; autodiff = AutoFiniteDiff())
2121

2222
@test lsys.A[] == lsys2.A[] == lsys3.A[] == -2
2323
@test lsys.B[] == lsys2.B[] == lsys3.B[] == 1
2424
@test lsys.C[] == lsys2.C[] == lsys3.C[] == 1
2525
@test lsys.D[] == lsys2.D[] == lsys3.D[] == 0
26+
@test extras == extras2
2627

2728
lsys, ssys = linearize(sys, [r], [r])
2829

0 commit comments

Comments
 (0)