Skip to content

Commit 26292ff

Browse files
vyuduAayushSabharwal
authored andcommitted
refactor: remove input_idxs output
1 parent faa5bbf commit 26292ff

File tree

4 files changed

+55
-46
lines changed

4 files changed

+55
-46
lines changed

src/inputoutput.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,6 @@ The return values also include the chosen state-realization (the remaining unkno
179179
180180
If `disturbance_inputs` is an array of variables, the generated dynamics function will preserve any state and dynamics associated with disturbance inputs, but the disturbance inputs themselves will (by default) not be included as inputs to the generated function. The use case for this is to generate dynamics for state observers that estimate the influence of unmeasured disturbances, and thus require unknown variables for the disturbance model, but without disturbance inputs since the disturbances are not available for measurement. To add an input argument corresponding to the disturbance inputs, either include the disturbance inputs among the control inputs, or set `disturbance_argument=true`, in which case an additional input argument `w` is added to the generated function `(x,u,p,t,w)->rhs`.
181181
182-
!!! note "Un-simplified system"
183-
This function expects `sys` to be un-simplified, i.e., `structural_simplify` or `@mtkbuild` should not be called on the system before passing it into this function. `generate_control_function` calls a special version of `structural_simplify` internally.
184-
185182
# Example
186183
187184
```
@@ -201,17 +198,17 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
201198
eval_expression = false,
202199
eval_module = @__MODULE__,
203200
kwargs...)
204-
isempty(inputs) && @warn("No unbound inputs were found in system.")
205201

202+
# Remove this when the ControlFunction gets merged.
203+
if !iscomplete(sys)
204+
error("A completed `ODESystem` is required. Call `complete` or `structural_simplify` on the system before creating the control function.")
205+
end
206+
isempty(inputs) && @warn("No unbound inputs were found in system.")
206207
if disturbance_inputs !== nothing
207208
# add to inputs for the purposes of io processing
208209
inputs = [inputs; disturbance_inputs]
209210
end
210211

211-
if !iscomplete(sys)
212-
sys, _ = io_preprocessing(sys, inputs, []; simplify, kwargs...)
213-
end
214-
215212
dvs = unknowns(sys)
216213
ps = parameters(sys; initial_parameters = true)
217214
ps = setdiff(ps, inputs)
@@ -257,8 +254,11 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
257254
(; f, dvs, ps, io_sys = sys)
258255
end
259256

260-
function inputs_to_parameters!(state::TransformationState, io)
261-
check_bound = io === nothing
257+
"""
258+
Turn input variables into parameters of the system.
259+
"""
260+
function inputs_to_parameters!(state::TransformationState, inputsyms)
261+
check_bound = inputsyms === nothing
262262
@unpack structure, fullvars, sys = state
263263
@unpack var_to_diff, graph, solvable_graph = structure
264264
@assert solvable_graph === nothing

src/systems/clock_inference.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
struct ClockInference{S}
2+
"""Tearing state."""
23
ts::S
4+
"""The time domain (discrete clock, continuous) of each equation."""
35
eq_domain::Vector{TimeDomain}
6+
"""The output time domain (discrete clock, continuous) of each variable."""
47
var_domain::Vector{TimeDomain}
8+
"""The set of variables with concrete domains."""
59
inferred::BitSet
610
end
711

@@ -67,6 +71,9 @@ function substitute_sample_time(ex, dt)
6771
end
6872
end
6973

74+
"""
75+
Update the equation-to-time domain mapping by inferring the time domain from the variables.
76+
"""
7077
function infer_clocks!(ci::ClockInference)
7178
@unpack ts, eq_domain, var_domain, inferred = ci
7279
@unpack var_to_diff, graph = ts.structure
@@ -132,6 +139,9 @@ function is_time_domain_conversion(v)
132139
input_timedomain(o) != output_timedomain(o)
133140
end
134141

142+
"""
143+
For multi-clock systems, create a separate system for each clock in the system, along with associated equations. Return the updated tearing state, and the sets of clocked variables associated with each time domain.
144+
"""
135145
function split_system(ci::ClockInference{S}) where {S}
136146
@unpack ts, eq_domain, var_domain, inferred = ci
137147
fullvars = get_fullvars(ts)
@@ -143,11 +153,14 @@ function split_system(ci::ClockInference{S}) where {S}
143153
cid_to_eq = Vector{Int}[]
144154
var_to_cid = Vector{Int}(undef, ndsts(graph))
145155
cid_to_var = Vector{Int}[]
156+
# cid_counter = number of clocks
146157
cid_counter = Ref(0)
147158
for (i, d) in enumerate(eq_domain)
148159
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
149160
continuous_id = continuous_id
150161

162+
# Fill the clock_to_id dict as you go,
163+
# ContinuousClock() => 1, ...
151164
get!(clock_to_id, d) do
152165
cid = (cid_counter[] += 1)
153166
push!(id_to_clock, d)

src/systems/systems.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,15 @@ topological sort of the observed equations in `sys`.
2727
+ `sort_eqs=true` controls whether equations are sorted lexicographically before simplification or not.
2828
"""
2929
function structural_simplify(
30-
sys::AbstractSystem, io = nothing; additional_passes = [], simplify = false, split = true,
30+
sys::AbstractSystem; additional_passes = [], simplify = false, split = true,
3131
allow_symbolic = false, allow_parameter = true, conservative = false, fully_determined = true,
32+
inputs = nothing, outputs = nothing,
33+
disturbance_inputs = nothing,
3234
kwargs...)
3335
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
34-
newsys′ = __structural_simplify(sys, io; simplify,
36+
newsys′ = __structural_simplify(sys; simplify,
3537
allow_symbolic, allow_parameter, conservative, fully_determined,
38+
inputs, outputs, disturbance_inputs,
3639
kwargs...)
3740
if newsys′ isa Tuple
3841
@assert length(newsys′) == 2
@@ -70,8 +73,9 @@ function __structural_simplify(sys::SDESystem, args...; kwargs...)
7073
return __structural_simplify(ODESystem(sys), args...; kwargs...)
7174
end
7275

73-
function __structural_simplify(
74-
sys::AbstractSystem, io = nothing; simplify = false, sort_eqs = true,
76+
function __structural_simplify(sys::AbstractSystem; simplify = false,
77+
inputs = Any[], outputs = Any[],
78+
disturbance_inputs = Any[],
7579
kwargs...)
7680
sys = expand_connections(sys)
7781
state = TearingState(sys; sort_eqs)
@@ -90,7 +94,8 @@ function __structural_simplify(
9094
end
9195
end
9296
if isempty(brown_vars)
93-
return structural_simplify!(state, io; simplify, kwargs...)
97+
return structural_simplify!(
98+
state; simplify, inputs, outputs, disturbance_inputs, kwargs...)
9499
else
95100
Is = Int[]
96101
Js = Int[]
@@ -122,8 +127,8 @@ function __structural_simplify(
122127
for (i, v) in enumerate(fullvars)
123128
if !iszero(new_idxs[i]) &&
124129
invview(var_to_diff)[i] === nothing]
125-
# TODO: IO is not handled.
126-
ode_sys = structural_simplify(sys, io; simplify, kwargs...)
130+
ode_sys = structural_simplify(
131+
sys; simplify, inputs, outputs, disturbance_inputs, kwargs...)
127132
eqs = equations(ode_sys)
128133
sorted_g_rows = zeros(Num, length(eqs), size(g, 2))
129134
for (i, eq) in enumerate(eqs)

src/systems/systemstructure.jl

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -657,29 +657,21 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
657657
printstyled(io, " SelectedState")
658658
end
659659

660-
# TODO: clean up
661-
function merge_io(io, inputs)
662-
isempty(inputs) && return io
663-
if io === nothing
664-
io = (inputs, [])
665-
else
666-
io = ([inputs; io[1]], io[2])
667-
end
668-
return io
669-
end
670-
671-
function structural_simplify!(state::TearingState, io = nothing; simplify = false,
660+
function structural_simplify!(state::TearingState; simplify = false,
672661
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
662+
inputs = nothing, outputs = nothing,
663+
disturbance_inputs = nothing,
673664
kwargs...)
674665
if state.sys isa ODESystem
675666
ci = ModelingToolkit.ClockInference(state)
676667
ci = ModelingToolkit.infer_clocks!(ci)
677668
time_domains = merge(Dict(state.fullvars .=> ci.var_domain),
678669
Dict(default_toterm.(state.fullvars) .=> ci.var_domain))
679-
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
680-
cont_io = merge_io(io, inputs[continuous_id])
681-
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify,
670+
tss, clocked_inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
671+
cont_inputs = [inputs; clocked_inputs[continuous_id]]
672+
sys = _structural_simplify!(tss[continuous_id]; simplify,
682673
check_consistency, fully_determined,
674+
cont_inputs, outputs, disturbance_inputs,
683675
kwargs...)
684676
if length(tss) > 1
685677
if continuous_id > 0
@@ -695,8 +687,9 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
695687
discrete_subsystems[i] = sys
696688
continue
697689
end
698-
dist_io = merge_io(io, inputs[i])
699-
ss, = _structural_simplify!(state, dist_io; simplify, check_consistency,
690+
disc_inputs = [inputs; clocked_inputs[i]]
691+
ss, = _structural_simplify!(state; simplify, check_consistency,
692+
inputs = disc_inputs, outputs, disturbance_inputs,
700693
fully_determined, kwargs...)
701694
append!(appended_parameters, inputs[i], unknowns(ss))
702695
discrete_subsystems[i] = ss
@@ -713,31 +706,29 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
713706
for sym in get_ps(sys)]
714707
@set! sys.ps = ps
715708
else
716-
sys, input_idxs = _structural_simplify!(state, io; simplify, check_consistency,
709+
sys = _structural_simplify!(state; simplify, check_consistency,
710+
inputs, outputs, disturbance_inputs,
717711
fully_determined, kwargs...)
718712
end
719-
has_io = io !== nothing
720-
return has_io ? (sys, input_idxs) : sys
713+
return sys
721714
end
722715

723-
function _structural_simplify!(state::TearingState, io; simplify = false,
716+
function _structural_simplify!(state::TearingState; simplify = false,
724717
check_consistency = true, fully_determined = true, warn_initialize_determined = false,
725718
dummy_derivative = true,
719+
inputs = nothing, outputs = nothing,
720+
disturbance_inputs = nothing,
726721
kwargs...)
727722
if fully_determined isa Bool
728723
check_consistency &= fully_determined
729724
else
730725
check_consistency = true
731726
end
732-
has_io = io !== nothing
727+
has_io = inputs !== nothing || outputs !== nothing
733728
orig_inputs = Set()
734729
if has_io
735-
ModelingToolkit.markio!(state, orig_inputs, io...)
736-
end
737-
if io !== nothing
738-
state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io)
739-
else
740-
input_idxs = 0:-1 # Empty range
730+
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs)
731+
state = ModelingToolkit.inputs_to_parameters!(state, inputs)
741732
end
742733
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
743734
if check_consistency
@@ -761,5 +752,5 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
761752
fullunknowns = [observables(sys); unknowns(sys)]
762753
@set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullunknowns)
763754

764-
ModelingToolkit.invalidate_cache!(sys), input_idxs
755+
ModelingToolkit.invalidate_cache!(sys)
765756
end

0 commit comments

Comments
 (0)