Skip to content

Commit 70e25e8

Browse files
vyuduAayushSabharwal
authored andcommitted
fix input output tests
1 parent c124625 commit 70e25e8

File tree

6 files changed

+37
-23
lines changed

6 files changed

+37
-23
lines changed

src/inputoutput.jl

+7-6
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,10 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
197197
simplify = false,
198198
eval_expression = false,
199199
eval_module = @__MODULE__,
200+
check_simplified = true,
200201
kwargs...)
201-
202202
# Remove this when the ControlFunction gets merged.
203-
if !iscomplete(sys)
203+
if check_simplified && !iscomplete(sys)
204204
error("A completed `ODESystem` is required. Call `complete` or `structural_simplify` on the system before creating the control function.")
205205
end
206206
isempty(inputs) && @warn("No unbound inputs were found in system.")
@@ -257,7 +257,7 @@ end
257257
"""
258258
Turn input variables into parameters of the system.
259259
"""
260-
function inputs_to_parameters!(state::TransformationState, inputsyms)
260+
function inputs_to_parameters!(state::TransformationState, inputsyms; is_disturbance = false)
261261
check_bound = inputsyms === nothing
262262
@unpack structure, fullvars, sys = state
263263
@unpack var_to_diff, graph, solvable_graph = structure
@@ -412,7 +412,7 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = Any[]; kwar
412412
@variables u(t)=0 [input = true] # New system input
413413
dsys = get_disturbance_system(dist)
414414

415-
if inputs === nothing
415+
if isempty(inputs)
416416
all_inputs = [u]
417417
else
418418
i = findfirst(isequal(dist.input), inputs)
@@ -427,8 +427,9 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = Any[]; kwar
427427
dist.input ~ u + dsys.output.u[1]]
428428
augmented_sys = ODESystem(eqs, t, systems = [dsys], name = gensym(:outer))
429429
augmented_sys = extend(augmented_sys, sys)
430+
ssys = structural_simplify(augmented_sys, inputs = all_inputs, disturbance_inputs = [d])
430431

431-
f, dvs, p, io_sys = generate_control_function(augmented_sys, all_inputs,
432-
[d]; kwargs...)
432+
f, dvs, p, io_sys = generate_control_function(ssys, all_inputs,
433+
[d]; check_simplified = false, kwargs...)
433434
f, augmented_sys, dvs, p, io_sys
434435
end

src/linearization.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -556,10 +556,11 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
556556
(; A, B, C, D, f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u), sys
557557
end
558558

559-
function markio!(state, orig_inputs, inputs, outputs; check = true)
559+
function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true)
560560
fullvars = get_fullvars(state)
561561
inputset = Dict{Any, Bool}(i => false for i in inputs)
562562
outputset = Dict{Any, Bool}(o => false for o in outputs)
563+
disturbanceset = Dict{Any, Bool}(d => false for d in disturbances)
563564
for (i, v) in enumerate(fullvars)
564565
if v in keys(inputset)
565566
if v in keys(outputset)
@@ -581,6 +582,12 @@ function markio!(state, orig_inputs, inputs, outputs; check = true)
581582
v = setio(v, false, false)
582583
fullvars[i] = v
583584
end
585+
586+
if v in keys(disturbanceset)
587+
v = setio(v, true, false)
588+
v = setdisturbance(v, true)
589+
fullvars[i] = v
590+
end
584591
end
585592
if check
586593
ikeys = keys(filter(!last, inputset))

src/systems/systemstructure.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -727,8 +727,8 @@ function _structural_simplify!(state::TearingState; simplify = false,
727727
has_io = inputs !== nothing || outputs !== nothing
728728
orig_inputs = Set()
729729
if has_io
730-
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs)
731-
state = ModelingToolkit.inputs_to_parameters!(state, inputs)
730+
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs)
731+
state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs])
732732
end
733733
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
734734
if check_consistency

src/variables.jl

+2
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ function isdisturbance(x)
354354
Symbolics.getmetadata(x, VariableDisturbance, false)
355355
end
356356

357+
setdisturbance(x, v) = setmetadata(x, VariableDisturbance, v)
358+
357359
function disturbances(sys)
358360
[filter(isdisturbance, unknowns(sys)); filter(isdisturbance, parameters(sys))]
359361
end

test/input_output_handling.jl

+17-13
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
77
@variables xx(t) some_input(t) [input = true]
88
eqs = [D(xx) ~ some_input]
99
@named model = ODESystem(eqs, t)
10-
@test_throws ExtraVariablesSystemException structural_simplify(model, ((), ()))
10+
@test_throws ExtraVariablesSystemException structural_simplify(model)
1111
if VERSION >= v"1.8"
1212
err = "In particular, the unset input(s) are:\n some_input(t)"
13-
@test_throws err structural_simplify(model, ((), ()))
13+
@test_throws err structural_simplify(model)
1414
end
1515

1616
# Test input handling
@@ -88,7 +88,7 @@ fsys4 = flatten(sys4)
8888
@variables x(t) y(t) [output = true]
8989
@test isoutput(y)
9090
@named sys = ODESystem([D(x) ~ -x, y ~ x], t) # both y and x are unbound
91-
syss = structural_simplify(sys) # This makes y an observed variable
91+
syss = structural_simplify(sys, outputs = [y]) # This makes y an observed variable
9292

9393
@named sys2 = ODESystem([D(x) ~ -sys.x, y ~ sys.y], t, systems = [sys])
9494

@@ -106,7 +106,7 @@ syss = structural_simplify(sys) # This makes y an observed variable
106106
@test isequal(unbound_outputs(sys2), [y])
107107
@test isequal(bound_outputs(sys2), [sys.y])
108108

109-
syss = structural_simplify(sys2)
109+
syss = structural_simplify(sys2, outputs = [sys.y])
110110

111111
@test !is_bound(syss, y)
112112
@test !is_bound(syss, x)
@@ -165,6 +165,7 @@ end
165165
]
166166

167167
@named sys = ODESystem(eqs, t)
168+
sys = structural_simplify(sys, inputs = [u])
168169
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys; simplify, split)
169170

170171
@test isequal(dvs[], x)
@@ -182,8 +183,8 @@ end
182183
]
183184

184185
@named sys = ODESystem(eqs, t)
185-
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(
186-
sys, [u], [d]; simplify, split)
186+
sys = structural_simplify(sys, inputs = [u], disturbance_inputs = [d])
187+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys; simplify, split)
187188

188189
@test isequal(dvs[], x)
189190
@test isempty(ps)
@@ -200,8 +201,9 @@ end
200201
]
201202

202203
@named sys = ODESystem(eqs, t)
204+
sys = structural_simplify(sys, inputs = [u], disturbance_inputs = [d])
203205
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(
204-
sys, [u], [d]; simplify, split, disturbance_argument = true)
206+
sys; simplify, split, disturbance_argument = true)
205207

206208
@test isequal(dvs[], x)
207209
@test isempty(ps)
@@ -265,9 +267,9 @@ eqs = [connect_sd(sd, mass1, mass2)
265267
@named _model = ODESystem(eqs, t)
266268
@named model = compose(_model, mass1, mass2, sd);
267269

270+
model = structural_simplify(model, inputs = [u])
268271
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(model, simplify = true)
269272
@test length(dvs) == 4
270-
@test length(ps) == length(parameters(model))
271273
p = MTKParameters(io_sys, [io_sys.u => NaN])
272274
x = ModelingToolkit.varmap_to_vars(
273275
merge(ModelingToolkit.defaults(model),
@@ -389,7 +391,7 @@ sys = structural_simplify(model)
389391

390392
## Disturbance models when plant has multiple inputs
391393
using ModelingToolkit, LinearAlgebra
392-
using ModelingToolkit: DisturbanceModel, io_preprocessing, get_iv, get_disturbance_system
394+
using ModelingToolkit: DisturbanceModel, get_iv, get_disturbance_system
393395
using ModelingToolkitStandardLibrary.Blocks
394396
A, C = [randn(2, 2) for i in 1:2]
395397
B = [1.0 0; 0 1.0]
@@ -433,6 +435,7 @@ matrices = ModelingToolkit.reorder_unknowns(
433435
]
434436

435437
@named sys = ODESystem(eqs, t)
438+
sys = structural_simplify(sys, inputs = [u])
436439
(; io_sys,) = ModelingToolkit.generate_control_function(sys, simplify = true)
437440
obsfn = ModelingToolkit.build_explicit_observed_function(
438441
io_sys, [x + u * t]; inputs = [u])
@@ -444,18 +447,19 @@ end
444447
@constants c = 2.0
445448
@variables x(t)
446449
eqs = [D(x) ~ c * x]
447-
@named sys = ODESystem(eqs, t, [x], [])
450+
@mtkbuild sys = ODESystem(eqs, t, [x], [])
448451

449-
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true)
450-
@test f([0.5], nothing, MTKParameters(io_sys, []), 0.0) [1.0]
452+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys)
453+
@test f[1]([0.5], nothing, MTKParameters(io_sys, []), 0.0) [1.0]
451454
end
452455

453456
@testset "With callable symbolic" begin
454457
@variables x(t)=0 u(t)=0 [input = true]
455458
@parameters p(::Real) = (x -> 2x)
456459
eqs = [D(x) ~ -x + p(u)]
457460
@named sys = ODESystem(eqs, t)
458-
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true)
461+
sys = structural_simplify(sys, inputs = [u])
462+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys)
459463
p = MTKParameters(io_sys, [])
460464
u = [1.0]
461465
x = [1.0]

test/reduction.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ eqs = [D(x) ~ σ * (y - x)
233233
u ~ z + a]
234234

235235
lorenz1 = ODESystem(eqs, t, name = :lorenz1)
236-
lorenz1_reduced, _ = structural_simplify(lorenz1, inputs = [z], outputs = [])
236+
lorenz1_reduced = structural_simplify(lorenz1, inputs = [z], outputs = [])
237237
@test z in Set(parameters(lorenz1_reduced))
238238

239239
# #2064

0 commit comments

Comments
 (0)