Skip to content

[WIP] feat: reduce reliance on metadata in structural_simplify #3540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 167 commits into
base: v10
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
167 commits
Select commit Hold shift + click to select a range
8d2e149
refactor: remove old `System` function
AayushSabharwal Apr 14, 2025
2c95e03
refactor: remove `odesystem.jl`
AayushSabharwal May 14, 2025
ca3e57b
refactor: add `_eq_unordered` to `utils.jl`
AayushSabharwal Apr 20, 2025
b07ed51
refactor: move `flatten_equations` to `utils.jl`
AayushSabharwal May 14, 2025
b953ee2
refactor: remove `sdesystem.jl`
AayushSabharwal Apr 20, 2025
d76d6f8
refactor: move `__num_isdiag_noise`, `get_num_diag_noise` to location…
AayushSabharwal Apr 20, 2025
ee1e8db
refactor: remove `discrete_system.jl`
AayushSabharwal Apr 20, 2025
d84fef3
refactor: remove `implicit_discrete_system.jl`
AayushSabharwal Apr 20, 2025
375eea0
remove jumpsystem.jl
AayushSabharwal May 14, 2025
3bf4d8d
refactor: move `JumpType` definition to `utils.jl`
AayushSabharwal Apr 22, 2025
429f58d
refactor: remove `nonlinearsystem.jl`
AayushSabharwal Apr 20, 2025
6767d4c
refactor: remove `optimizationsystem.jl`
AayushSabharwal Apr 21, 2025
bde2c29
refactor: remove `constraints_system.jl`
AayushSabharwal Apr 21, 2025
1aef225
remove abstractodesystem.jl
AayushSabharwal May 14, 2025
b803240
refactor: do not rely on `ArrayPartition` when unit-checking jumps
AayushSabharwal May 14, 2025
c2f3070
feat: add unified `System` type
AayushSabharwal May 14, 2025
02ec43c
feat: add getters for new `System` fields
AayushSabharwal Apr 14, 2025
d33c220
feat: add hierarchical aggregator functions for jumps, brownians and …
AayushSabharwal Apr 14, 2025
016f337
refactor: move `constraints` to `abstractsystem.jl`
AayushSabharwal Apr 21, 2025
a20655f
refactor: don't warn about system supertype for `System`
AayushSabharwal Apr 14, 2025
5246b29
refactor: port `stochastic_integral_transform` and `Girsanov_transform`
AayushSabharwal Apr 20, 2025
8b6a8e6
refactor: move `eval_or_rgf` to `codegen_utils.jl`
AayushSabharwal Apr 27, 2025
3465f41
feat: allow `GeneratedFunctionWrapper` to compile functions and build…
AayushSabharwal Apr 27, 2025
59e513a
feat: add `maybe_compile_function`
AayushSabharwal Apr 28, 2025
0b389f1
refactor: fix and document `delay_to_function`, implement it for `Sys…
AayushSabharwal Apr 15, 2025
e296ef4
feat: add initial codegen for `System`
AayushSabharwal Apr 13, 2025
bbb918c
refactor: port `build_explicit_observed_function` to `codegen.jl`
AayushSabharwal Apr 20, 2025
7228f1e
feat: add `@fallback_iip_specialize`
AayushSabharwal Apr 14, 2025
21eb8a2
feat: add `check_compatible_system`
AayushSabharwal Apr 14, 2025
31bcf21
feat: implement `generate_initializesystem` for `System`
AayushSabharwal Apr 14, 2025
da17dbc
refactor: remove `generate_factorized_W`
AayushSabharwal Apr 16, 2025
21ea3e4
fix: fix `remake` for `IntervalNonlinearProblem`
AayushSabharwal Apr 16, 2025
11b369b
feat: add `replace` kwarg to `add_toterms!`
AayushSabharwal May 14, 2025
0d5a377
refactor: centralize problem kwargs handling
AayushSabharwal May 14, 2025
87f41f6
refactor: move `filter_kwargs` and `SymbolicTstops` to `problem_utils…
AayushSabharwal Apr 21, 2025
4b080c8
feat: support returning `Expr` from `SymbolicTstops` and `ObservedFun…
AayushSabharwal Apr 28, 2025
d088463
refactor: pass `u0` and `p` as kwargs in `process_SciMLProblem`
AayushSabharwal May 14, 2025
d0844c5
feat: add `maybe_codegen_scimlfn` and `maybe_codegen_scimlproblem`
AayushSabharwal Apr 28, 2025
bf9fb81
feat: implement `ODEProblem` and `ODEFunction` for `System`
AayushSabharwal Apr 14, 2025
208fd8c
feat: implement `DDEFunction`, `DDEProblem` for `System`
AayushSabharwal May 14, 2025
3f60b79
feat: implement `DAEProblem` and `DAEFunction` for `System`
AayushSabharwal May 14, 2025
37b0f16
feat: implement `SDEProblem` and `SDEFunction` for `System`
AayushSabharwal May 14, 2025
9bf8d76
feat: implement `SDDEProblem`, `SDDEFunction` for `System`
AayushSabharwal Apr 16, 2025
1dc9073
feat: implement `NonlinearProblem`, `NonlinearFunction` for `System`
AayushSabharwal May 14, 2025
0660676
feat: implement `IntervalNonlinearProblem`, `IntervalNonlinearFunctio…
AayushSabharwal Apr 16, 2025
3a96bb1
feat: implement `ImplicitDiscreteProblem` `ImplicitDiscreteFunction` …
AayushSabharwal May 14, 2025
d119c8b
feat: implement `DiscreteProblem` and `DiscreteFunction` for `System`
AayushSabharwal May 14, 2025
f1999b1
feat: allow manually choosing time-independent initialization
AayushSabharwal Apr 17, 2025
29a829f
feat: implement `OptimizationProblem` and `OptimizationFunction` for …
AayushSabharwal May 14, 2025
c1c1ada
feat: implement `JumpProblem` for `System`
AayushSabharwal May 14, 2025
6ec4561
feat: add `InitializationProblem`
AayushSabharwal May 14, 2025
490965e
feat: implement `NonlinearLeastSquaresProblem` for `System`
AayushSabharwal Apr 20, 2025
9200646
refactor: port `SCCNonlinearProblem` to separate file
AayushSabharwal Apr 20, 2025
43e2455
feat: implement `BVProblem` for `System`
AayushSabharwal May 14, 2025
66bb56b
feat: implement `SteadyStateProblem` for `System`
AayushSabharwal Apr 21, 2025
c184da0
fix: fix `toexpr(::AbstractSystem)`
AayushSabharwal Apr 21, 2025
37979be
fix: fix `extend(::AbstractSystem)`
AayushSabharwal Apr 21, 2025
6105a2d
fix: fix `substitute(::AbstractSystem, _...)`
AayushSabharwal Apr 21, 2025
ba3a5de
fix: construct `System` in `@mtkmodel`
AayushSabharwal Apr 21, 2025
c00265c
docs: fix docstring of `process_SciMLProblem`
AayushSabharwal Apr 21, 2025
6daea8f
refactor: allow real values in `costs`
AayushSabharwal Apr 25, 2025
6765ffd
refactor: remove old clock handling code, retain error messages
AayushSabharwal Apr 21, 2025
063fdf0
fix: allow simplifying systems with noise
AayushSabharwal Apr 21, 2025
1cd1bfa
refactor: remove `schedule(sys)`
AayushSabharwal Apr 21, 2025
a4ede08
feat: set system scheduling information in `structural_simplify`
AayushSabharwal Apr 21, 2025
208002b
refactor: remove references to `ODESystem` in source code
AayushSabharwal Apr 21, 2025
827d04f
refactor: remove `__structural_simplify(::JumpSystem)`
AayushSabharwal Apr 21, 2025
6adf58e
refactor: remove references to `SDESystem` in source code
AayushSabharwal Apr 21, 2025
9fa1c8a
refactor: remove references to `NonlinearSystem`
AayushSabharwal Apr 21, 2025
efeca6b
refactor: remove references to `DiscreteSystem`
AayushSabharwal Apr 21, 2025
774bf46
refactor: do not use `sys.substitutions`
AayushSabharwal Apr 22, 2025
f0c97c1
test: replace `ODESystem` with `System`
AayushSabharwal May 14, 2025
075f00f
test: replace `NonlinearSystem` with `System`
AayushSabharwal Apr 22, 2025
ab21e35
test: replace `ImplicitDiscreteSystem` with `System`
AayushSabharwal Apr 22, 2025
55bbef0
test: replace `DiscreteSystem` with `System`
AayushSabharwal Apr 22, 2025
c3cc065
fix: ensure equations are `Vector{Equation}` in `generate_initializes…
AayushSabharwal Apr 25, 2025
2889896
test: fix usage of array equations in test
AayushSabharwal Apr 25, 2025
9664374
test: ensure equations passed to system are `Vector{Equation}`
AayushSabharwal Apr 25, 2025
84eaaad
refactor: remove `process_equations`
AayushSabharwal Apr 25, 2025
8b3b300
docs: document `collect_var!`
AayushSabharwal Apr 25, 2025
2d03523
refactor: change default operator in `collect_vars!` to `Symbolics.Op…
AayushSabharwal Apr 25, 2025
99345eb
feat: add `validate_operator`
AayushSabharwal Apr 25, 2025
51d5bd0
refactor: document `collect_vars!` and use `validate_operator`
AayushSabharwal Apr 25, 2025
bfa5282
feat: add `is_floatingpoint_symtype`
AayushSabharwal Apr 25, 2025
c2ded8c
refactor: use `is_floatingpoint_symtype` in `is_variable_floatingpoint`
AayushSabharwal Apr 25, 2025
fd6e6aa
test: don't pass dvs/ps to `ODEFunction`
AayushSabharwal Apr 25, 2025
53d7122
test: fix shadowing of `System` in tests
AayushSabharwal Apr 25, 2025
396473a
test: pass `u0map` and `tspan` to `ODEProblem`
AayushSabharwal Apr 25, 2025
61734de
test: create `JumpProblem` directly
AayushSabharwal Apr 25, 2025
773a340
test: pass `Vector{Equation}` to `System`
AayushSabharwal Apr 25, 2025
538270d
refactor: remove `build_torn_function`, `tearing_assignments`
AayushSabharwal Apr 27, 2025
2ce2965
refactor: remove `get_substitutions`, `has_substitutions` field getters
AayushSabharwal Apr 27, 2025
717df75
refactor: implement `empty_substitutions` and `get_substitutions` usi…
AayushSabharwal Apr 27, 2025
81af64c
refactor: remove `get_substitutions_and_solved_unknowns`
AayushSabharwal Apr 27, 2025
88fcaf4
refactor: update `tearing_substitute_expr`, `full_equations` to use `…
AayushSabharwal Apr 27, 2025
b286994
refactor: do not use `get_substitutions` in `get_cmap`
AayushSabharwal Apr 27, 2025
6e1a8a7
test: fix mass matrix tests
AayushSabharwal Apr 27, 2025
dbdce94
test: fix odesystem tests
AayushSabharwal Apr 27, 2025
b2d9aa4
refactor: rename `generate_function` to `generate_rhs`
AayushSabharwal Apr 27, 2025
887859c
fix: fix type-piracy of `Symbolics.rename`
AayushSabharwal Apr 29, 2025
63d0f02
refactor: remove `systems/diffeqs/modelingtoolkitize.jl`
AayushSabharwal Apr 29, 2025
3f6baff
feat: add `modelingtoolkitize` for `ODEProblem`
AayushSabharwal Apr 29, 2025
e06e654
feat: add `modelingtoolkitize` for `SDEProblem`
AayushSabharwal Apr 29, 2025
980c413
feat: add `add_accumulations`
AayushSabharwal Apr 29, 2025
69f16ac
test: fix `odesystem` tests
AayushSabharwal Apr 29, 2025
1528f50
fix: validate that `Sample` operates on unknowns
AayushSabharwal Apr 29, 2025
d24f372
test: fix `test/structural_transformation/utils.jl`
AayushSabharwal Apr 29, 2025
d912b21
test: simplify test for metadata retention in `complete`
AayushSabharwal Apr 29, 2025
3618586
test: improve readability of dependency graph tests
AayushSabharwal Apr 29, 2025
63d4877
test: fix usage of `ODEProblemExpr` in lowering test
AayushSabharwal Apr 29, 2025
db7a7a0
test: remove test for specifying type of system in `@mtkmodel`
AayushSabharwal Apr 29, 2025
1fd04cc
test: fix parameter dependencies test
AayushSabharwal Apr 29, 2025
72a8cdc
test: fix symbolic events test
AayushSabharwal Apr 29, 2025
feee2ac
test: fix modelingtoolkitize test
AayushSabharwal Apr 29, 2025
aab4251
test: remove outdated test
AayushSabharwal May 1, 2025
c29367b
refactor: remove old `modelingtoolkitize(::OptimizationProblem)`
AayushSabharwal May 1, 2025
02b7dde
feat: add `modelingtoolkitize(::OptimizationProblem)`
AayushSabharwal May 1, 2025
52316ba
refactor: remove old `modelingtoolkitize(::NonlinearProblem)`
AayushSabharwal May 1, 2025
3e4398e
feat: add `modelingtoolkitize(::NonlinearProblem)`
AayushSabharwal May 1, 2025
0dd372d
refactor: remove old `modelingtoolkitize(::ODEProblem)` and `::SDEPro…
AayushSabharwal May 1, 2025
e818644
feat: add `structural_simplify` for optimization systems
AayushSabharwal May 2, 2025
6421536
fix: change default `consolidate` to `default_consolidate` in `@mtkmo…
AayushSabharwal May 7, 2025
f771f34
fix: handle `Shift`s in `is_diff_equation`
AayushSabharwal May 8, 2025
e95b1ae
feat: implement `calculate_hessian` and `hessian_sparsity` for `System`
AayushSabharwal May 8, 2025
1fbad40
feat: make delay processing more modular in `build_function_wrapper`
AayushSabharwal May 8, 2025
f681588
fix: fix `generate_cost` for time-dependent (BV) systems
AayushSabharwal May 8, 2025
78bdb60
fix: default `build_initializeprob` to `supports_initialization(sys)`
AayushSabharwal May 8, 2025
71faa50
fix: disallow simplification of jump systems
AayushSabharwal May 8, 2025
ae6aa08
fix: remove usage of `is_scalar_noise` kwarg
AayushSabharwal May 8, 2025
d71efa2
feat: inspect jumps in `collect_scoped_vars!`
AayushSabharwal May 8, 2025
a234bd8
test: comment out jac/tgrad caching test
AayushSabharwal May 8, 2025
00b510b
test: update BVProblem tests to new semantics
AayushSabharwal May 8, 2025
725ee73
test: update DDE tests to new semantics
AayushSabharwal May 8, 2025
3709b32
test: make debugging tests more reproducible
AayushSabharwal May 8, 2025
b64a308
test: update discrete system tests to new semantics
AayushSabharwal May 8, 2025
7fe6bb2
test: update implicit discrete system tests to new semantics
AayushSabharwal May 8, 2025
f08f045
test: remove redundant initial values tests
AayushSabharwal May 8, 2025
5400c1d
test: update initialization system tests to new semantics
AayushSabharwal May 8, 2025
ab1ce1d
test: fix jump system tests
AayushSabharwal May 8, 2025
6d1dea7
test: remove redundant namespacing tests
AayushSabharwal May 8, 2025
ba66b67
fix: fix `independent_variables`
AayushSabharwal May 8, 2025
5d33972
feat: export newly added functions
AayushSabharwal May 8, 2025
7554962
feat: add `noise_to_brownians`
AayushSabharwal May 14, 2025
c3c65fb
feat: add `convert_system_indepvar`
AayushSabharwal May 14, 2025
cba0860
fix: fix structural simplification for SDEs
AayushSabharwal May 14, 2025
490b8a5
refactor: remove `convert_system`
AayushSabharwal May 14, 2025
0bc9594
test: fix nonlinearsystem tests
AayushSabharwal May 14, 2025
f44de1e
test: fix optimizationsystem tests
AayushSabharwal May 14, 2025
762de22
test: fix sdesystem tests
AayushSabharwal May 14, 2025
82da509
fix: unwrap in `add_toterms!`, don't override existing values
AayushSabharwal May 14, 2025
b696485
fix: unwrap `varmap` and add toterms in `better_varmap_to_vars`
AayushSabharwal May 14, 2025
3d5f2a4
test: allow for `StalledSuccess` retcode in initializationsystem test
AayushSabharwal May 14, 2025
6b1d3f2
test: update mtkparameters tests
AayushSabharwal May 14, 2025
fb14d69
test: update serialization tests
AayushSabharwal May 14, 2025
cd51045
test: fix symbolic indexing interface tests
AayushSabharwal May 14, 2025
b1f092b
fix: fix `Pre` parameter discovery for `AffectSystem`
AayushSabharwal May 14, 2025
7032e3a
fix: fix `compile_condition`, respect `eval_expression` and `eval_mod…
AayushSabharwal May 14, 2025
4c74e4e
fix: respect `eval_expression`, `eval_module` in `compile_equational_…
AayushSabharwal May 14, 2025
ba41cb1
test: fix symbolic event tests
AayushSabharwal May 14, 2025
28d4de5
fix: update CasADi extension to new semantics
AayushSabharwal May 15, 2025
ddd89d7
fix: update InfiniteOpt extension to new semantics
AayushSabharwal May 15, 2025
473ab62
fix: update `optimal_control_interface.jl` to new semantics
AayushSabharwal May 15, 2025
593c7c8
fix: recognize delayed derivatives in `isdelay`
AayushSabharwal May 15, 2025
fcf1b5a
fix: use 2-argument `consolidate` in dynamic optimization tests
AayushSabharwal May 15, 2025
552b4ee
feat: add `expected_scope_depth`
AayushSabharwal Apr 3, 2025
ce959f7
fix: don't `toparam` inside `Initial`
AayushSabharwal Apr 4, 2025
8f71f30
feat: reduce reliance on metadata in `structural_simplify`
AayushSabharwal Apr 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ext/MTKBifurcationKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct ObservableRecordFromSolution{S, T}
# A Vector of pairs (Symbolic => value) with the default values of all system variables and parameters.
subs_vals::T

function ObservableRecordFromSolution(nsys::NonlinearSystem,
function ObservableRecordFromSolution(nsys::System,
plot_var,
bif_idx,
u0_vals,
Expand Down Expand Up @@ -82,7 +82,7 @@ end
### Creates BifurcationProblem Overloads ###

# When input is a NonlinearSystem.
function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
function BifurcationKit.BifurcationProblem(nsys::System,
u0_bif,
ps,
bif_par,
Expand All @@ -92,7 +92,7 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
jac = true,
kwargs...)
if !ModelingToolkit.iscomplete(nsys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
error("A completed `System` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
end
@set! nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`
# Creates F and J functions.
Expand Down Expand Up @@ -144,11 +144,11 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
end

# When input is a ODESystem.
function BifurcationKit.BifurcationProblem(osys::ODESystem, args...; kwargs...)
function BifurcationKit.BifurcationProblem(osys::System, args...; kwargs...)
if !ModelingToolkit.iscomplete(osys)
error("A completed `ODESystem` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
end
nsys = NonlinearSystem([0 ~ eq.rhs for eq in full_equations(osys)],
nsys = System([0 ~ eq.rhs for eq in full_equations(osys)],
unknowns(osys),
parameters(osys);
observed = observed(osys),
Expand Down
76 changes: 37 additions & 39 deletions ext/MTKCasADiDynamicOptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ function (M::MXLinearInterpolation)(τ)
end

"""
CasADiDynamicOptProblem(sys::ODESystem, u0, tspan, p; dt, steps)
CasADiDynamicOptProblem(sys::System, u0, tspan, p; dt, steps)

Convert an ODESystem representing an optimal control system into a CasADi model
Convert an System representing an optimal control system into a CasADi model
for solving using optimization. Must provide either `dt`, the timestep between collocation
points (which, along with the timespan, determines the number of points), or directly
provide the number of points as `steps`.
Expand All @@ -68,10 +68,10 @@ The optimization variables:
- a vector-of-vectors V representing the controls as an interpolation array

The constraints are:
- The set of user constraints passed to the ODESystem via `constraints`
- The set of user constraints passed to the System via `constraints`
- The solver constraints that encode the time-stepping used by the solver
"""
function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
dt = nothing,
steps = nothing,
guesses = Dict(), kwargs...)
Expand All @@ -80,7 +80,8 @@ function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
t = tspan !== nothing ? tspan[1] : tspan, output_type = MX, kwargs...)

pmap = Dict{Any, Any}(pmap)
pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))
MTK.evaluate_varmap!(pmap, keys(pmap))
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
model = init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)

Expand Down Expand Up @@ -143,15 +144,15 @@ function set_casadi_bounds!(model, sys, pmap)
for (i, u) in enumerate(unknowns(sys))
if MTK.hasbounds(u)
lo, hi = MTK.getbounds(u)
subject_to!(opti, Symbolics.fixpoint_sub(lo, pmap) <= U.u[i, :])
subject_to!(opti, U.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
subject_to!(opti, Symbolics.fast_substitute(lo, pmap) <= U.u[i, :])
subject_to!(opti, U.u[i, :] <= Symbolics.fast_substitute(hi, pmap))
end
end
for (i, v) in enumerate(MTK.unbound_inputs(sys))
if MTK.hasbounds(v)
lo, hi = MTK.getbounds(v)
subject_to!(opti, Symbolics.fixpoint_sub(lo, pmap) <= V.u[i, :])
subject_to!(opti, V.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
subject_to!(opti, Symbolics.fast_substitute(lo, pmap) <= V.u[i, :])
subject_to!(opti, V.u[i, :] <= Symbolics.fast_substitute(hi, pmap))
end
end
end
Expand All @@ -167,15 +168,15 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
@unpack opti, U, V, tₛ = model

iv = MTK.get_iv(sys)
conssys = MTK.get_constraintsystem(sys)
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
jconstraints = MTK.get_constraints(sys)
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing

stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
cons_unknowns = map(MTK.default_toterm, unknowns(conssys))
cons_dvs, cons_ps = MTK.process_constraint_system(
jconstraints, Set(unknowns(sys)), parameters(sys), iv; validate = false)

auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in cons_dvs])
jconstraints = substitute_casadi_vars(model, sys, pmap, jconstraints; is_free_t, auxmap)
# Manually substitute fixed-t variables
for (i, cons) in enumerate(jconstraints)
Expand Down Expand Up @@ -207,9 +208,8 @@ end

function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
@unpack opti, U, V, tₛ = model
jcosts = copy(MTK.get_costs(sys))
consolidate = MTK.get_consolidate(sys)
if isnothing(jcosts) || isempty(jcosts)
jcosts = cost(sys)
if Symbolics._iszero(jcosts)
minimize!(opti, MX(0))
return
end
Expand All @@ -218,24 +218,22 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])

jcosts = substitute_casadi_vars(model, sys, pmap, jcosts; is_free_t)
jcosts = substitute_casadi_vars(model, sys, pmap, [jcosts]; is_free_t)[1]
# Substitute fixed-time variables.
for i in 1:length(jcosts)
costvars = MTK.vars(jcosts[i])
for st in costvars
MTK.iscall(st) || continue
x = operation(st)
t = only(arguments(st))
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
cv = U
else
idx = ctidxmap[x(iv)]
cv = V
end
jcosts[i] = Symbolics.substitute(jcosts[i], Dict(x(t) => cv(t)[idx]))
costvars = MTK.vars(jcosts)
for st in costvars
MTK.iscall(st) || continue
x = operation(st)
t = only(arguments(st))
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
cv = U
else
idx = ctidxmap[x(iv)]
cv = V
end
jcosts = Symbolics.substitute(jcosts, Dict(x(t) => cv(t)[idx]))
end

dt = U.t[2] - U.t[1]
Expand All @@ -249,9 +247,9 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
# Approximate integral as sum.
intmap[int] = dt * tₛ * sum(arg)
end
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
jcosts = MTK.value.(jcosts)
minimize!(opti, MX(MTK.value(consolidate(jcosts))))
jcosts = Symbolics.substitute(jcosts, intmap)
jcosts = MTK.value(jcosts)
minimize!(opti, MX(jcosts))
end

function substitute_casadi_vars(
Expand All @@ -264,20 +262,20 @@ function substitute_casadi_vars(
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]

exprs = map(c -> Symbolics.fixpoint_sub(c, auxmap), exprs)
exprs = map(c -> Symbolics.fixpoint_sub(c, Dict(pmap)), exprs)
exprs = map(c -> Symbolics.fast_substitute(c, auxmap), exprs)
exprs = map(c -> Symbolics.fast_substitute(c, Dict(pmap)), exprs)
# tf means different things in different contexts; a [tf] in a cost function
# should be tₛ, while a x(tf) should translate to x[1]
if is_free_t
free_t_map = Dict([[x(tₛ) => U.u[i, end] for (i, x) in enumerate(x_ops)];
[c(tₛ) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
exprs = map(c -> Symbolics.fixpoint_sub(c, free_t_map), exprs)
exprs = map(c -> Symbolics.fast_substitute(c, free_t_map), exprs)
end

# for variables like x(t)
whole_interval_map = Dict([[v => U.u[i, :] for (i, v) in enumerate(sts)];
[v => V.u[i, :] for (i, v) in enumerate(cts)]])
exprs = map(c -> Symbolics.fixpoint_sub(c, whole_interval_map), exprs)
exprs = map(c -> Symbolics.fast_substitute(c, whole_interval_map), exprs)
exprs
end

Expand Down
2 changes: 1 addition & 1 deletion ext/MTKFMIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,
end

eqs = [observed; diffeqs]
return ODESystem(eqs, t, states, params; parameter_dependencies, defaults = defs,
return System(eqs, t, states, params; parameter_dependencies, defaults = defs,
discrete_events = [instance_management_callback], name, initialization_eqs)
end

Expand Down
58 changes: 30 additions & 28 deletions ext/MTKInfiniteOptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ struct InfiniteOptDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
end

"""
JuMPDynamicOptProblem(sys::ODESystem, u0, tspan, p; dt)
JuMPDynamicOptProblem(sys::System, u0, tspan, p; dt)

Convert an ODESystem representing an optimal control system into a JuMP model
Convert a System representing an optimal control system into a JuMP model
for solving using optimization. Must provide either `dt`, the timestep between collocation
points (which, along with the timespan, determines the number of points), or directly
provide the number of points as `steps`.
Expand All @@ -51,10 +51,10 @@ The optimization variables:
- a vector-of-vectors V representing the controls as an interpolation array

The constraints are:
- The set of user constraints passed to the ODESystem via `constraints`
- The set of user constraints passed to the System via `constraints`
- The solver constraints that encode the time-stepping used by the solver
"""
function MTK.JuMPDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
function MTK.JuMPDynamicOptProblem(sys::System, u0map, tspan, pmap;
dt = nothing,
steps = nothing,
guesses = Dict(), kwargs...)
Expand All @@ -63,24 +63,25 @@ function MTK.JuMPDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)

pmap = Dict{Any, Any}(pmap)
pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))
MTK.evaluate_varmap!(pmap, keys(pmap))
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
model = init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)

JuMPDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
end

"""
InfiniteOptDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap; dt)
InfiniteOptDynamicOptProblem(sys::System, u0map, tspan, pmap; dt)

Convert an ODESystem representing an optimal control system into a InfiniteOpt model
Convert System representing an optimal control system into a InfiniteOpt model
for solving using optimization. Must provide `dt` for determining the length
of the interpolation arrays.

Related to `JuMPDynamicOptProblem`, but directly adds the differential equations
of the system as derivative constraints, rather than using a solver tableau.
"""
function MTK.InfiniteOptDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
function MTK.InfiniteOptDynamicOptProblem(sys::System, u0map, tspan, pmap;
dt = nothing,
steps = nothing,
guesses = Dict(), kwargs...)
Expand All @@ -89,7 +90,8 @@ function MTK.InfiniteOptDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)

pmap = Dict{Any, Any}(pmap)
pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))
MTK.evaluate_varmap!(pmap, keys(pmap))
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
model = init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)

Expand Down Expand Up @@ -150,29 +152,28 @@ function set_jump_bounds!(model, sys, pmap)
for (i, u) in enumerate(unknowns(sys))
if MTK.hasbounds(u)
lo, hi = MTK.getbounds(u)
set_lower_bound(U[i], Symbolics.fixpoint_sub(lo, pmap))
set_upper_bound(U[i], Symbolics.fixpoint_sub(hi, pmap))
set_lower_bound(U[i], Symbolics.fast_substitute(lo, pmap))
set_upper_bound(U[i], Symbolics.fast_substitute(hi, pmap))
end
end

V = model[:V]
for (i, v) in enumerate(MTK.unbound_inputs(sys))
if MTK.hasbounds(v)
lo, hi = MTK.getbounds(v)
set_lower_bound(V[i], Symbolics.fixpoint_sub(lo, pmap))
set_upper_bound(V[i], Symbolics.fixpoint_sub(hi, pmap))
set_lower_bound(V[i], Symbolics.fast_substitute(lo, pmap))
set_upper_bound(V[i], Symbolics.fast_substitute(hi, pmap))
end
end
end

function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap; is_free_t = false)
jcosts = MTK.get_costs(sys)
consolidate = MTK.get_consolidate(sys)
if isnothing(jcosts) || isempty(jcosts)
jcosts = cost(sys)
if Symbolics._iszero(jcosts)
@objective(model, Min, 0)
return
end
jcosts = substitute_jump_vars(model, sys, pmap, jcosts; is_free_t)
jcosts = substitute_jump_vars(model, sys, pmap, [jcosts]; is_free_t)[1]
tₛ = is_free_t ? model[:tf] : 1

# Substitute integral
Expand All @@ -187,17 +188,18 @@ function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap; is_free
hi = haskey(pmap, hi) ? 1 : MTK.value(hi)
intmap[int] = tₛ * InfiniteOpt.∫(arg, model[:t], lo, hi)
end
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
@objective(model, Min, consolidate(jcosts))
jcosts = Symbolics.substitute(jcosts, intmap)
@objective(model, Min, MTK.value(jcosts))
end

function add_user_constraints!(model::InfiniteModel, sys, pmap; is_free_t = false)
conssys = MTK.get_constraintsystem(sys)
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
jconstraints = MTK.get_constraints(sys)
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
cons_dvs, cons_ps = MTK.process_constraint_system(
jconstraints, Set(unknowns(sys)), parameters(sys), MTK.get_iv(sys); validate = false)

if is_free_t
for u in MTK.get_unknowns(conssys)
for u in cons_dvs
x = MTK.operation(u)
t = only(arguments(u))
if (MTK.symbolic_type(t) === MTK.NotSymbolic())
Expand All @@ -206,7 +208,7 @@ function add_user_constraints!(model::InfiniteModel, sys, pmap; is_free_t = fals
end
end

auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in cons_dvs])
jconstraints = substitute_jump_vars(model, sys, pmap, jconstraints; auxmap, is_free_t)

# Substitute to-term'd variables
Expand Down Expand Up @@ -235,25 +237,25 @@ function substitute_jump_vars(model, sys, pmap, exprs; auxmap = Dict(), is_free_
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]

exprs = map(c -> Symbolics.fixpoint_sub(c, auxmap), exprs)
exprs = map(c -> Symbolics.fixpoint_sub(c, Dict(pmap)), exprs)
exprs = map(c -> Symbolics.fast_substitute(c, auxmap), exprs)
exprs = map(c -> Symbolics.fast_substitute(c, Dict(pmap)), exprs)
if is_free_t
tf = model[:tf]
free_t_map = Dict([[x(tf) => U[i](1) for (i, x) in enumerate(x_ops)];
[c(tf) => V[i](1) for (i, c) in enumerate(c_ops)]])
exprs = map(c -> Symbolics.fixpoint_sub(c, free_t_map), exprs)
exprs = map(c -> Symbolics.fast_substitute(c, free_t_map), exprs)
end

# for variables like x(t)
whole_interval_map = Dict([[v => U[i] for (i, v) in enumerate(sts)];
[v => V[i] for (i, v) in enumerate(cts)]])
exprs = map(c -> Symbolics.fixpoint_sub(c, whole_interval_map), exprs)
exprs = map(c -> Symbolics.fast_substitute(c, whole_interval_map), exprs)

# for variables like x(1.0)
fixed_t_map = Dict([[x_ops[i] => U[i] for i in 1:length(U)];
[c_ops[i] => V[i] for i in 1:length(V)]])

exprs = map(c -> Symbolics.fixpoint_sub(c, fixed_t_map), exprs)
exprs = map(c -> Symbolics.fast_substitute(c, fixed_t_map), exprs)
exprs
end

Expand Down
Loading
Loading