Skip to content

Commit 1d2a604

Browse files
vyuduAayushSabharwal
authored andcommitted
fix: fix linearization tests
1 parent 21b231a commit 1d2a604

File tree

6 files changed

+23
-25
lines changed

6 files changed

+23
-25
lines changed

src/inputoutput.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ has_var(ex, x) = x ∈ Set(get_variables(ex))
163163
(f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function(
164164
sys::AbstractODESystem,
165165
inputs = unbound_inputs(sys),
166-
disturbance_inputs = nothing;
166+
disturbance_inputs = Any[];
167167
implicit_dae = false,
168168
simplify = false,
169169
)
@@ -287,7 +287,7 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
287287
push!(new_fullvars, v)
288288
end
289289
end
290-
ninputs == 0 && return (state, 1:0)
290+
ninputs == 0 && return state
291291

292292
nvars = ndsts(graph) - ninputs
293293
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))
@@ -316,14 +316,13 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
316316
@set! sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters))
317317
ps = parameters(sys)
318318

319-
if io !== nothing
320-
inputs, = io
319+
if inputsyms !== nothing
321320
# Change order of new parameters to correspond to user-provided order in argument `inputs`
322321
d = Dict{Any, Int}()
323322
for (i, inp) in enumerate(new_parameters)
324323
d[inp] = i
325324
end
326-
permutation = [d[i] for i in inputs]
325+
permutation = [d[i] for i in inputsyms]
327326
new_parameters = new_parameters[permutation]
328327
end
329328

@@ -332,8 +331,7 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
332331
@set! state.sys = sys
333332
@set! state.fullvars = new_fullvars
334333
@set! state.structure = structure
335-
base_params = length(ps)
336-
return state, (base_params + 1):(base_params + length(new_parameters)) # (1:length(new_parameters)) .+ base_params
334+
return state
337335
end
338336

339337
"""
@@ -359,7 +357,7 @@ function get_disturbance_system(dist::DisturbanceModel{<:ODESystem})
359357
end
360358

361359
"""
362-
(f_oop, f_ip), augmented_sys, dvs, p = add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing)
360+
(f_oop, f_ip), augmented_sys, dvs, p = add_input_disturbance(sys, dist::DisturbanceModel, inputs = Any[])
363361
364362
Add a model of an unmeasured disturbance to `sys`. The disturbance model is an instance of [`DisturbanceModel`](@ref).
365363
@@ -408,7 +406,7 @@ model_outputs = [model.inertia1.w, model.inertia2.w, model.inertia1.phi, model.i
408406
409407
`f_oop` will have an extra state corresponding to the integrator in the disturbance model. This state will not be affected by any input, but will affect the dynamics from where it enters, in this case it will affect additively from `model.torque.tau.u`.
410408
"""
411-
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing; kwargs...)
409+
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = Any[]; kwargs...)
412410
t = get_iv(sys)
413411
@variables d(t)=0 [disturbance = true]
414412
@variables u(t)=0 [input = true] # New system input

src/linearization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ function linearization_function(sys::AbstractSystem, inputs,
126126
end
127127

128128
lin_fun = LinearizationFunction(
129-
diff_idxs, alge_idxs, length(unknowns(sys)),
129+
diff_idxs, alge_idxs, inputs, length(unknowns(sys)),
130130
prob, h, u0 === nothing ? nothing : similar(u0), uf_jac, h_jac, pf_jac,
131131
hp_jac, initializealg, initialization_kwargs)
132132
return lin_fun, sys

src/systems/diffeqs/odesystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ Generates a function that computes the observed value(s) `ts` in the system `sys
470470
- `eval_expression = false`: If true and `expression = false`, evaluates the returned function in the module `eval_module`
471471
- `output_type = Array` the type of the array generated by a out-of-place vector-valued function
472472
- `param_only = false` if true, only allow the generated function to access system parameters
473-
- `inputs = nothing` additinoal symbolic variables that should be provided to the generated function
473+
- `inputs = Any[]` additional symbolic variables that should be provided to the generated function
474474
- `checkbounds = true` checks bounds if true when destructuring parameters
475475
- `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail.
476476
- `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist.
@@ -500,8 +500,8 @@ For example, a function `g(op, unknowns, p..., inputs, t)` will be the in-place
500500
an array of inputs `inputs` is given, and `param_only` is false for a time-dependent system.
501501
"""
502502
function build_explicit_observed_function(sys, ts;
503-
inputs = nothing,
504-
disturbance_inputs = nothing,
503+
inputs = Any[],
504+
disturbance_inputs = Any[],
505505
disturbance_argument = false,
506506
expression = false,
507507
eval_expression = false,
@@ -574,13 +574,13 @@ function build_explicit_observed_function(sys, ts;
574574
else
575575
(unknowns(sys),)
576576
end
577-
if inputs === nothing
577+
if isempty(inputs)
578578
inputs = ()
579579
else
580580
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
581581
inputs = (inputs,)
582582
end
583-
if disturbance_inputs !== nothing
583+
if !isempty(disturbance_inputs)
584584
# Disturbance inputs may or may not be included as inputs, depending on disturbance_argument
585585
ps = setdiff(ps, disturbance_inputs)
586586
end

src/systems/systems.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ topological sort of the observed equations in `sys`.
2929
function structural_simplify(
3030
sys::AbstractSystem; additional_passes = [], simplify = false, split = true,
3131
allow_symbolic = false, allow_parameter = true, conservative = false, fully_determined = true,
32-
inputs = nothing, outputs = nothing,
33-
disturbance_inputs = nothing,
32+
inputs = Any[], outputs = Any[],
33+
disturbance_inputs = Any[],
3434
kwargs...)
3535
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
3636
newsys′ = __structural_simplify(sys; simplify,

src/systems/systemstructure.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -659,8 +659,8 @@ end
659659

660660
function structural_simplify!(state::TearingState; simplify = false,
661661
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
662-
inputs = nothing, outputs = nothing,
663-
disturbance_inputs = nothing,
662+
inputs = Any[], outputs = Any[],
663+
disturbance_inputs = Any[],
664664
kwargs...)
665665
if state.sys isa ODESystem
666666
ci = ModelingToolkit.ClockInference(state)
@@ -671,7 +671,7 @@ function structural_simplify!(state::TearingState; simplify = false,
671671
cont_inputs = [inputs; clocked_inputs[continuous_id]]
672672
sys = _structural_simplify!(tss[continuous_id]; simplify,
673673
check_consistency, fully_determined,
674-
cont_inputs, outputs, disturbance_inputs,
674+
inputs = cont_inputs, outputs, disturbance_inputs,
675675
kwargs...)
676676
if length(tss) > 1
677677
if continuous_id > 0
@@ -716,8 +716,8 @@ end
716716
function _structural_simplify!(state::TearingState; simplify = false,
717717
check_consistency = true, fully_determined = true, warn_initialize_determined = false,
718718
dummy_derivative = true,
719-
inputs = nothing, outputs = nothing,
720-
disturbance_inputs = nothing,
719+
inputs = Any[], outputs = Any[],
720+
disturbance_inputs = Any[],
721721
kwargs...)
722722
if fully_determined isa Bool
723723
check_consistency &= fully_determined

test/downstream/linearize.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ connections = [f.y ~ c.r # filtered reference to controller reference
8787

8888
@named cl = ODESystem(connections, t, systems = [f, c, p])
8989

90-
lsys0, ssys = linearize(cl, [f.u], [p.x])
90+
lsys0, ssys = linearize(cl)
9191
desired_order = [f.x, p.x]
9292
lsys = ModelingToolkit.reorder_unknowns(lsys0, unknowns(ssys), desired_order)
93-
lsys1, ssys = linearize(cl, [f.u], [p.x]; autodiff = AutoFiniteDiff())
93+
lsys1, ssys = linearize(cl; autodiff = AutoFiniteDiff())
9494
lsys2 = ModelingToolkit.reorder_unknowns(lsys1, unknowns(ssys), desired_order)
9595

9696
@test lsys.A == lsys2.A == [-2 0; 1 -2]
@@ -266,7 +266,7 @@ closed_loop = ODESystem(connections, t, systems = [model, pid, filt, sensor, r,
266266
filt.xd => 0.0
267267
])
268268

269-
@test_nowarn linearize(closed_loop, :r, :y; warn_empty_op = false)
269+
@test_nowarn linearize(closed_loop; warn_empty_op = false)
270270

271271
# https://discourse.julialang.org/t/mtk-change-in-linearize/115760/3
272272
@mtkmodel Tank_noi begin

0 commit comments

Comments
 (0)