From 4d1c2f62bb89433d25177671d568b382682fec9b Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Wed, 7 May 2025 23:16:10 -0700 Subject: [PATCH 1/7] Use getsym to access modified and observed values from the integrator --- src/systems/imperative_affect.jl | 17 +++++------------ test/symbolic_events.jl | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index b0742e70a7..554b7235a5 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -189,18 +189,13 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. else zeros(sz) end - obs_fun = build_explicit_observed_function( - sys, Symbolics.scalarize.(obs_exprs); - mkarray = (es, _) -> MakeTuple(es)) - obs_sym_tuple = (obs_syms...,) + geto_funs = NamedTuple{(obs_syms...,)}((getsym.((sys,), obs_exprs)...,)) # okay so now to generate the stuff to assign it back into the system + getm_funs = NamedTuple{(mod_syms...,)}((getsym.((sys,), mod_exprs)...,)) + mod_pairs = mod_exprs .=> mod_syms mod_names = (mod_syms...,) - mod_og_val_fun = build_explicit_observed_function( - sys, Symbolics.scalarize.(first.(mod_pairs)); - mkarray = (es, _) -> MakeTuple(es)) - upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,)) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing @@ -212,12 +207,10 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. let user_affect = func(affect), ctx = context(affect) function (integ) # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens - modvals = mod_og_val_fun(integ.u, integ.p, integ.t) - upd_component_array = NamedTuple{mod_names}(modvals) + upd_component_array = _generated_readback(integ, getm_funs) # update the observed values - obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun( - integ.u, integ.p, integ.t)) + obs_component_array = _generated_readback(integ, geto_funs) # let the user do their thing upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 5c0a2ee7fc..d207222cd0 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1461,3 +1461,26 @@ end sys = structural_simplify(sys) sol = solve(ODEProblem(sys, [], (0.0, 1.0)), Tsit5()) end + +@testset "Tuples in ImperativeAffect arguments" begin + @mtkmodel ImperativeAffectTupleMWE begin + @parameters begin + y(t) = 1.0 + end + @variables begin + x(t) = 0.0 + end + @equations begin + D(x) ~ y + end + @continuous_events begin + (x ~ 0.5) => ModelingToolkit.ImperativeAffect( + observed = (; mypars = (x, 2 * x)), modified = (; y)) do m, o, c, i + return (; y = 2 * o.mypars[1] + o.mypars[2]) + end + end + end + @mtkbuild sys = ImperativeAffectTupleMWE() + prob = ODEProblem(sys, [], (0.0, 1.0)) + sol = solve(prob, Tsit5()) +end From 51acbfb34e0c88ae5fffbac23fa153c6de492cdc Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Wed, 7 May 2025 23:17:14 -0700 Subject: [PATCH 2/7] Forgot readback --- src/systems/callbacks.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index a58ff3f8ec..eb490a5417 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -970,6 +970,14 @@ end end) end +@generated function _generated_readback(integ, getters::NamedTuple{NS1, <:Tuple}) where {NS1} + getter_exprs = [] + for name in NS1 + push!(getter_exprs, :($name = getters.$name(integ))) + end + return :((; $(getter_exprs...))) +end + function check_assignable(sys, sym) if symbolic_type(sym) == ScalarSymbolic() is_variable(sys, sym) || is_parameter(sys, sym) From a1287dab1bdc01af19a0edaf2ceb6ca0b19f5683 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Wed, 7 May 2025 23:26:36 -0700 Subject: [PATCH 3/7] Avoid a writeback if the user affect returns nothing --- src/systems/imperative_affect.jl | 9 +++++---- test/symbolic_events.jl | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 554b7235a5..72e037e021 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -216,10 +216,11 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) # write the new values back to the integrator - _generated_writeback(integ, upd_funs, upd_vals) - - for idx in save_idxs - SciMLBase.save_discretes!(integ, idx) + if !isnothing(upd_vals) + _generated_writeback(integ, upd_funs, upd_vals) + for idx in save_idxs + SciMLBase.save_discretes!(integ, idx) + end end end end diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index d207222cd0..0cc141a668 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1484,3 +1484,27 @@ end prob = ODEProblem(sys, [], (0.0, 1.0)) sol = solve(prob, Tsit5()) end + +@testset "ImperativeAffect skips writing back when nothing is returned" begin + @mtkmodel ImperativeAffectTupleMWE begin + @parameters begin + y(t) = 1.0 + end + @variables begin + x(t) = 0.0 + end + @equations begin + D(x) ~ y + end + @continuous_events begin + (x ~ 0.5) => ModelingToolkit.ImperativeAffect( + observed = (; mypars = (x, 2 * x)), modified = (; y)) do m, o, c, i + return nothing + end + end + end + @mtkbuild sys = ImperativeAffectTupleMWE() + prob = ODEProblem(sys, [], (0.0, 1.0)) + sol = solve(prob, Tsit5()) + @test length(sol[sys.y]) == 1 +end From 72dcda6cc8882abd60dcaafb3175486ae5a6c2f7 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Wed, 7 May 2025 23:30:31 -0700 Subject: [PATCH 4/7] Rename the MWE in the nothing writeback case --- test/symbolic_events.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 0cc141a668..a409d0c942 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1486,7 +1486,7 @@ end end @testset "ImperativeAffect skips writing back when nothing is returned" begin - @mtkmodel ImperativeAffectTupleMWE begin + @mtkmodel ImperativeAffectWriteNothingMWE begin @parameters begin y(t) = 1.0 end @@ -1503,7 +1503,7 @@ end end end end - @mtkbuild sys = ImperativeAffectTupleMWE() + @mtkbuild sys = ImperativeAffectWriteNothingMWE() prob = ODEProblem(sys, [], (0.0, 1.0)) sol = solve(prob, Tsit5()) @test length(sol[sys.y]) == 1 From 744f652a6e1b5d77085b9eb1410420e2510018eb Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Wed, 7 May 2025 23:33:16 -0700 Subject: [PATCH 5/7] Document the nothing return handling --- src/systems/imperative_affect.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 72e037e021..c209c9931e 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -26,7 +26,7 @@ The NamedTuple returned from `f` includes the values to be written back to the s Where we use Setfield to copy the tuple `m` with a new value for `x`, then return the modified value of `m`. All values updated by the tuple must have names originally declared in `modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it -in the returned tuple, in which case the associated field will not be updated. +in the returned tuple, in which case the associated field will not be updated. To avoid writing back, either return `nothing` or an empty named tuple. """ @kwdef struct ImperativeAffect f::Any From e6f1f06206821c9afa6c0b52dd68f6e130d5afc3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 8 May 2025 13:36:17 +0530 Subject: [PATCH 6/7] fix: don't pass `dt` as an observed to `ImperativeAffect` in MTKFMIExt --- ext/MTKFMIExt.jl | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/ext/MTKFMIExt.jl b/ext/MTKFMIExt.jl index 5cfe9a82ef..79410d9fc9 100644 --- a/ext/MTKFMIExt.jl +++ b/ext/MTKFMIExt.jl @@ -261,7 +261,7 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6, # use `ImperativeAffect` for instance management here cb_observed = (; inputs = __mtk_internal_x, params = copy(params), - t, wrapper, dt = communication_step_size) + t, wrapper) cb_modified = (;) # modify the outputs if present if symbolic_type(__mtk_internal_o) != NotSymbolic() @@ -272,11 +272,12 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6, cb_modified = (cb_modified..., states = __mtk_internal_u) end initialize_affect = MTK.ImperativeAffect(fmiCSInitialize!; observed = cb_observed, - modified = cb_modified, ctx = _functor) + modified = cb_modified, ctx = (_functor, communication_step_size)) finalize_affect = MTK.FunctionalAffect(fmiFinalize!, [], [wrapper], []) # the callback affect performs the stepping step_affect = MTK.ImperativeAffect( - fmiCSStep!; observed = cb_observed, modified = cb_modified, ctx = _functor) + fmiCSStep!; observed = cb_observed, modified = cb_modified, + ctx = (_functor, communication_step_size)) instance_management_callback = MTK.SymbolicDiscreteCallback( communication_step_size, step_affect; initialize = initialize_affect, finalize = finalize_affect, reinitializealg = reinitializealg @@ -775,7 +776,8 @@ the value being the output vector if the FMU has output variables. `o` should co Initializes the FMU. Only for use with CoSimulation FMUs. """ -function fmiCSInitialize!(m, o, ctx::FMI2CSFunctor, integrator) +function fmiCSInitialize!(m, o, ctx, integrator) + functor::FMI2CSFunctor, dt = ctx states = isdefined(m, :states) ? m.states : () inputs = o.inputs params = o.params @@ -787,10 +789,10 @@ function fmiCSInitialize!(m, o, ctx::FMI2CSFunctor, integrator) instance = get_instance_CS!(wrapper, states, inputs, params, t) if isdefined(m, :states) - @statuscheck FMI.fmi2GetReal!(instance, ctx.state_value_references, m.states) + @statuscheck FMI.fmi2GetReal!(instance, functor.state_value_references, m.states) end if isdefined(m, :outputs) - @statuscheck FMI.fmi2GetReal!(instance, ctx.output_value_references, m.outputs) + @statuscheck FMI.fmi2GetReal!(instance, functor.output_value_references, m.outputs) end return m @@ -804,13 +806,13 @@ periodically to communicte with the CoSimulation FMU. Has the same requirements `fmiCSInitialize!` for `m` and `o`, with the addition that `o` should have a key `:dt` with the value being the communication step size. """ -function fmiCSStep!(m, o, ctx::FMI2CSFunctor, integrator) +function fmiCSStep!(m, o, ctx, integrator) + functor::FMI2CSFunctor, dt = ctx wrapper = o.wrapper states = isdefined(m, :states) ? m.states : () inputs = o.inputs params = o.params t = o.t - dt = o.dt instance = get_instance_CS!(wrapper, states, inputs, params, integrator.t) if !isempty(inputs) @@ -820,10 +822,10 @@ function fmiCSStep!(m, o, ctx::FMI2CSFunctor, integrator) @statuscheck FMI.fmi2DoStep(instance, integrator.t - dt, dt, FMI.fmi2True) if isdefined(m, :states) - @statuscheck FMI.fmi2GetReal!(instance, ctx.state_value_references, m.states) + @statuscheck FMI.fmi2GetReal!(instance, functor.state_value_references, m.states) end if isdefined(m, :outputs) - @statuscheck FMI.fmi2GetReal!(instance, ctx.output_value_references, m.outputs) + @statuscheck FMI.fmi2GetReal!(instance, functor.output_value_references, m.outputs) end return m @@ -874,7 +876,8 @@ end """ $(TYPEDSIGNATURES) """ -function fmiCSInitialize!(m, o, ctx::FMI3CSFunctor, integrator) +function fmiCSInitialize!(m, o, ctx, integrator) + functor::FMI3CSFunctor, dt = ctx states = isdefined(m, :states) ? m.states : () inputs = o.inputs params = o.params @@ -885,10 +888,11 @@ function fmiCSInitialize!(m, o, ctx::FMI3CSFunctor, integrator) end instance = get_instance_CS!(wrapper, states, inputs, params, t) if isdefined(m, :states) - @statuscheck FMI.fmi3GetFloat64!(instance, ctx.state_value_references, m.states) + @statuscheck FMI.fmi3GetFloat64!(instance, functor.state_value_references, m.states) end if isdefined(m, :outputs) - @statuscheck FMI.fmi3GetFloat64!(instance, ctx.output_value_references, m.outputs) + @statuscheck FMI.fmi3GetFloat64!( + instance, functor.output_value_references, m.outputs) end return m @@ -897,13 +901,13 @@ end """ $(TYPEDSIGNATURES) """ -function fmiCSStep!(m, o, ctx::FMI3CSFunctor, integrator) +function fmiCSStep!(m, o, ctx, integrator) + functor::FMI3CSFunctor, dt = ctx wrapper = o.wrapper states = isdefined(m, :states) ? m.states : () inputs = o.inputs params = o.params t = o.t - dt = o.dt instance = get_instance_CS!(wrapper, states, inputs, params, integrator.t) if !isempty(inputs) @@ -921,10 +925,11 @@ function fmiCSStep!(m, o, ctx::FMI3CSFunctor, integrator) @assert earlyReturn[] == FMI.fmi3False if isdefined(m, :states) - @statuscheck FMI.fmi3GetFloat64!(instance, ctx.state_value_references, m.states) + @statuscheck FMI.fmi3GetFloat64!(instance, functor.state_value_references, m.states) end if isdefined(m, :outputs) - @statuscheck FMI.fmi3GetFloat64!(instance, ctx.output_value_references, m.outputs) + @statuscheck FMI.fmi3GetFloat64!( + instance, functor.output_value_references, m.outputs) end return m From 8caf3e4a7f0855770e855cc5288cfea938be7e63 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 8 May 2025 15:06:26 +0530 Subject: [PATCH 7/7] fix: fix dispatch in modified affect functions --- ext/MTKFMIExt.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ext/MTKFMIExt.jl b/ext/MTKFMIExt.jl index 79410d9fc9..2e9fa70de2 100644 --- a/ext/MTKFMIExt.jl +++ b/ext/MTKFMIExt.jl @@ -776,8 +776,8 @@ the value being the output vector if the FMU has output variables. `o` should co Initializes the FMU. Only for use with CoSimulation FMUs. """ -function fmiCSInitialize!(m, o, ctx, integrator) - functor::FMI2CSFunctor, dt = ctx +function fmiCSInitialize!(m, o, ctx::Tuple{FMI2CSFunctor, Vararg}, integrator) + functor, dt = ctx states = isdefined(m, :states) ? m.states : () inputs = o.inputs params = o.params @@ -806,8 +806,8 @@ periodically to communicte with the CoSimulation FMU. Has the same requirements `fmiCSInitialize!` for `m` and `o`, with the addition that `o` should have a key `:dt` with the value being the communication step size. """ -function fmiCSStep!(m, o, ctx, integrator) - functor::FMI2CSFunctor, dt = ctx +function fmiCSStep!(m, o, ctx::Tuple{FMI2CSFunctor, Vararg}, integrator) + functor, dt = ctx wrapper = o.wrapper states = isdefined(m, :states) ? m.states : () inputs = o.inputs @@ -876,8 +876,8 @@ end """ $(TYPEDSIGNATURES) """ -function fmiCSInitialize!(m, o, ctx, integrator) - functor::FMI3CSFunctor, dt = ctx +function fmiCSInitialize!(m, o, ctx::Tuple{FMI3CSFunctor, Vararg}, integrator) + functor, dt = ctx states = isdefined(m, :states) ? m.states : () inputs = o.inputs params = o.params @@ -901,8 +901,8 @@ end """ $(TYPEDSIGNATURES) """ -function fmiCSStep!(m, o, ctx, integrator) - functor::FMI3CSFunctor, dt = ctx +function fmiCSStep!(m, o, ctx::Tuple{FMI3CSFunctor, Vararg}, integrator) + functor, dt = ctx wrapper = o.wrapper states = isdefined(m, :states) ? m.states : () inputs = o.inputs