Skip to content

Commit 21b231a

Browse files
vyuduAayushSabharwal
authored andcommitted
refactor: use new structural_simplify in linearization
1 parent 26292ff commit 21b231a

File tree

2 files changed

+26
-30
lines changed

2 files changed

+26
-30
lines changed

src/linearization.jl

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,8 @@ function linearization_function(sys::AbstractSystem, inputs,
5858
outputs = mapreduce(vcat, outputs; init = []) do var
5959
symbolic_type(var) == ArraySymbolic() ? collect(var) : [var]
6060
end
61-
ssys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs;
62-
simplify,
63-
kwargs...)
61+
ssys = structural_simplify(sys; inputs, outputs, simplify, kwargs...)
62+
diff_idxs, alge_idxs = eq_idxs(ssys)
6463
if zero_dummy_der
6564
dummyder = setdiff(unknowns(ssys), unknowns(sys))
6665
defs = Dict(x => 0.0 for x in dummyder)
@@ -87,9 +86,9 @@ function linearization_function(sys::AbstractSystem, inputs,
8786

8887
p = parameter_values(prob)
8988
t0 = current_time(prob)
90-
inputvals = [p[idx] for idx in input_idxs]
89+
inputvals = [prob.ps[i] for i in inputs]
9190

92-
hp_fun = let fun = h, setter = setp_oop(sys, input_idxs)
91+
hp_fun = let fun = h, setter = setp_oop(sys, inputs)
9392
function hpf(du, input, u, p, t)
9493
p = setter(p, input)
9594
fun(du, u, p, t)
@@ -113,7 +112,7 @@ function linearization_function(sys::AbstractSystem, inputs,
113112
# observed function is a `GeneratedFunctionWrapper` with iip component
114113
h_jac = PreparedJacobian{true}(h, similar(prob.u0, size(outputs)), autodiff,
115114
prob.u0, DI.Constant(p), DI.Constant(t0))
116-
pf_fun = let fun = prob.f, setter = setp_oop(sys, input_idxs)
115+
pf_fun = let fun = prob.f, setter = setp_oop(sys, inputs)
117116
function pff(du, input, u, p, t)
118117
p = setter(p, input)
119118
SciMLBase.ParamJacobianWrapper(fun, t, u)(du, p)
@@ -127,12 +126,24 @@ function linearization_function(sys::AbstractSystem, inputs,
127126
end
128127

129128
lin_fun = LinearizationFunction(
130-
diff_idxs, alge_idxs, input_idxs, length(unknowns(sys)),
129+
diff_idxs, alge_idxs, length(unknowns(sys)),
131130
prob, h, u0 === nothing ? nothing : similar(u0), uf_jac, h_jac, pf_jac,
132131
hp_jac, initializealg, initialization_kwargs)
133132
return lin_fun, sys
134133
end
135134

135+
function eq_idxs(sys::AbstractSystem)
136+
eqs = equations(sys)
137+
alg_start_idx = findfirst(!isdiffeq, eqs)
138+
if alg_start_idx === nothing
139+
alg_start_idx = length(eqs) + 1
140+
end
141+
diff_idxs = 1:(alg_start_idx - 1)
142+
alge_idxs = alg_start_idx:length(eqs)
143+
144+
diff_idxs, alge_idxs
145+
end
146+
136147
"""
137148
$(TYPEDEF)
138149
@@ -192,7 +203,7 @@ A callable struct which linearizes a system.
192203
$(TYPEDFIELDS)
193204
"""
194205
struct LinearizationFunction{
195-
DI <: AbstractVector{Int}, AI <: AbstractVector{Int}, II, P <: ODEProblem,
206+
DI <: AbstractVector{Int}, AI <: AbstractVector{Int}, I, P <: ODEProblem,
196207
H, C, J1, J2, J3, J4, IA <: SciMLBase.DAEInitializationAlgorithm, IK}
197208
"""
198209
The indexes of differential equations in the linearized system.
@@ -206,7 +217,7 @@ struct LinearizationFunction{
206217
The indexes of parameters in the linearized system which represent
207218
input variables.
208219
"""
209-
input_idxs::II
220+
inputs::I
210221
"""
211222
The number of unknowns in the linearized system.
212223
"""
@@ -281,6 +292,7 @@ function (linfun::LinearizationFunction)(u, p, t)
281292
end
282293

283294
fun = linfun.prob.f
295+
input_vals = [linfun.prob.ps[i] for i in linfun.inputs]
284296
if u !== nothing # Handle systems without unknowns
285297
linfun.num_states == length(u) ||
286298
error("Number of unknown variables ($(linfun.num_states)) does not match the number of input unknowns ($(length(u)))")
@@ -294,15 +306,15 @@ function (linfun::LinearizationFunction)(u, p, t)
294306
end
295307
fg_xz = linfun.uf_jac(u, DI.Constant(p), DI.Constant(t))
296308
h_xz = linfun.h_jac(u, DI.Constant(p), DI.Constant(t))
297-
fg_u = linfun.pf_jac([p[idx] for idx in linfun.input_idxs],
309+
fg_u = linfun.pf_jac(input_vals,
298310
DI.Constant(u), DI.Constant(p), DI.Constant(t))
299311
else
300312
linfun.num_states == 0 ||
301313
error("Number of unknown variables (0) does not match the number of input unknowns ($(length(u)))")
302314
fg_xz = zeros(0, 0)
303-
h_xz = fg_u = zeros(0, length(linfun.input_idxs))
315+
h_xz = fg_u = zeros(0, length(linfun.inputs))
304316
end
305-
h_u = linfun.hp_jac([p[idx] for idx in linfun.input_idxs],
317+
h_u = linfun.hp_jac(input_vals,
306318
DI.Constant(u), DI.Constant(p), DI.Constant(t))
307319
(f_x = fg_xz[linfun.diff_idxs, linfun.diff_idxs],
308320
f_z = fg_xz[linfun.diff_idxs, linfun.alge_idxs],
@@ -482,9 +494,8 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
482494
outputs; simplify = false, allow_input_derivatives = false,
483495
eval_expression = false, eval_module = @__MODULE__,
484496
kwargs...)
485-
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(
486-
sys, inputs, outputs; simplify,
487-
kwargs...)
497+
sys = structural_simplify(sys; inputs, outputs, simplify, kwargs...)
498+
diff_idxs, alge_idxs = eq_idxs(sys)
488499
sts = unknowns(sys)
489500
t = get_iv(sys)
490501
ps = parameters(sys; initial_parameters = true)

src/systems/abstractsystem.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2486,21 +2486,6 @@ function eliminate_constants(sys::AbstractSystem)
24862486
return sys
24872487
end
24882488

2489-
function io_preprocessing(sys::AbstractSystem, inputs,
2490-
outputs; simplify = false, kwargs...)
2491-
sys, input_idxs = structural_simplify(sys, (inputs, outputs); simplify, kwargs...)
2492-
2493-
eqs = equations(sys)
2494-
alg_start_idx = findfirst(!isdiffeq, eqs)
2495-
if alg_start_idx === nothing
2496-
alg_start_idx = length(eqs) + 1
2497-
end
2498-
diff_idxs = 1:(alg_start_idx - 1)
2499-
alge_idxs = alg_start_idx:length(eqs)
2500-
2501-
sys, diff_idxs, alge_idxs, input_idxs
2502-
end
2503-
25042489
@latexrecipe function f(sys::AbstractSystem)
25052490
return latexify(equations(sys))
25062491
end

0 commit comments

Comments
 (0)