Skip to content

Use getsym instead of an explicitly generated function and avoid writeback if nothing is returned #3610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,14 @@ end
end)
end

@generated function _generated_readback(integ, getters::NamedTuple{NS1, <:Tuple}) where {NS1}
getter_exprs = []
for name in NS1
push!(getter_exprs, :($name = getters.$name(integ)))
end
return :((; $(getter_exprs...)))
end

function check_assignable(sys, sym)
if symbolic_type(sym) == ScalarSymbolic()
is_variable(sys, sym) || is_parameter(sys, sym)
Expand Down
17 changes: 5 additions & 12 deletions src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,18 +189,13 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
else
zeros(sz)
end
obs_fun = build_explicit_observed_function(
sys, Symbolics.scalarize.(obs_exprs);
mkarray = (es, _) -> MakeTuple(es))
obs_sym_tuple = (obs_syms...,)
geto_funs = NamedTuple{(obs_syms...,)}((getsym.((sys,), obs_exprs)...,))

# okay so now to generate the stuff to assign it back into the system
getm_funs = NamedTuple{(mod_syms...,)}((getsym.((sys,), mod_exprs)...,))

mod_pairs = mod_exprs .=> mod_syms
mod_names = (mod_syms...,)
mod_og_val_fun = build_explicit_observed_function(
sys, Symbolics.scalarize.(first.(mod_pairs));
mkarray = (es, _) -> MakeTuple(es))

upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))

if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
Expand All @@ -212,12 +207,10 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
let user_affect = func(affect), ctx = context(affect)
function (integ)
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
modvals = mod_og_val_fun(integ.u, integ.p, integ.t)
upd_component_array = NamedTuple{mod_names}(modvals)
upd_component_array = _generated_readback(integ, getm_funs)

# update the observed values
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(
integ.u, integ.p, integ.t))
obs_component_array = _generated_readback(integ, geto_funs)

# let the user do their thing
upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ)
Expand Down
23 changes: 23 additions & 0 deletions test/symbolic_events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1461,3 +1461,26 @@ end
sys = structural_simplify(sys)
sol = solve(ODEProblem(sys, [], (0.0, 1.0)), Tsit5())
end

@testset "Tuples in ImperativeAffect arguments" begin
@mtkmodel ImperativeAffectTupleMWE begin
@parameters begin
y(t) = 1.0
end
@variables begin
x(t) = 0.0
end
@equations begin
D(x) ~ y
end
@continuous_events begin
(x ~ 0.5) => ModelingToolkit.ImperativeAffect(
observed = (; mypars = (x, 2 * x)), modified = (; y)) do m, o, c, i
return (; y = 2 * o.mypars[1] + o.mypars[2])
end
end
end
@mtkbuild sys = ImperativeAffectTupleMWE()
prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob, Tsit5())
end
Loading