OK,
I’ve implemented a proof-of-concept supporting periodic callbacks. Patch below.
It allows me to write code such as:
@variables t x(t)=1 v(t)=0
D = Differential(t)
# Periodic affect: flip the sign of `x`. States are looked up by their string name.
affect(u, p, t) = (u["x(t)"] = -u["x(t)"])
period = 1.0
@named ball = ODESystem([D(x) ~ v, D(v) ~ -9.8], t,
                        periodic_events = period => affect)
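This is not a good interface: the affect function has to look up states (and parameters) by their string names, e.g. `u["x(t)"]`.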
Patch
--- a/src/ModelingToolkit.jl
+++ b/src/ModelingToolkit.jl
@@ -211,4 +211,6 @@ export modelingtoolkitize
export @variables, @parameters
export @named, @nonamespace, @namespace, extend, compose
+export MyPeriodicCallback, MyPeriodicCallbacks, MY_NULL_AFFECT
+export periodic_events
end # module
diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl
index e2cba876f..b05d53d33 100644
--- a/src/systems/abstractsystem.jl
+++ b/src/systems/abstractsystem.jl
@@ -217,6 +217,50 @@ namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallbac
namespace_equation.(affect_equations(cb),
(s,)))
+# Default affect: accepts any arguments and does nothing.
+const MY_NULL_AFFECT = (args...; kwargs...) -> nothing
+
+"""
+    MyPeriodicCallback(Δt, affect = MY_NULL_AFFECT)
+    MyPeriodicCallback(Δt => affect)
+
+Periodic event that runs `affect` every `Δt` of the independent variable.
+`affect(u, p, t)` receives the states `u` and parameters `p` as `Dict`s keyed by
+string name (e.g. `u["x(t)"]`). A zero `Δt` marks the callback as empty.
+"""
+struct MyPeriodicCallback
+    Δt::Number
+    affect::Function
+    function MyPeriodicCallback(Δt::Number, affect = MY_NULL_AFFECT)
+        new(Δt, affect)
+    end
+end
+
+function Base.:(==)(e1::MyPeriodicCallback, e2::MyPeriodicCallback)
+    isequal(e1.Δt, e2.Δt) && isequal(e1.affect, e2.affect)
+end
+Base.isempty(cb::MyPeriodicCallback) = iszero(cb.Δt)
+# Hash the fields directly: `foldr` over a `Function` fails (it is not iterable).
+Base.hash(cb::MyPeriodicCallback, s::UInt) = hash(cb.affect, hash(cb.Δt, s))
+
+MyPeriodicCallback(p::Pair) = MyPeriodicCallback(p[1], p[2])
+MyPeriodicCallback(cb::MyPeriodicCallback) = cb # passthrough
+
+# Normalize the `periodic_events` keyword into a `Vector{MyPeriodicCallback}`.
+MyPeriodicCallbacks(cb::MyPeriodicCallback) = [cb]
+MyPeriodicCallbacks(cbs::Vector{<:MyPeriodicCallback}) = cbs
+MyPeriodicCallbacks(cbs::Vector) = MyPeriodicCallback.(cbs)
+MyPeriodicCallbacks(Δt::Number) = MyPeriodicCallbacks(MyPeriodicCallback(Δt))
+MyPeriodicCallbacks(others) = MyPeriodicCallbacks(MyPeriodicCallback(others))
+MyPeriodicCallbacks(::Nothing) = MyPeriodicCallback[] # no periodic events
+
+# Accessors. `map` always returns a vector (also for 0 or 1 callbacks), unlike
+# `reduce(vcat, …)`, which errors on empty input and unwraps a single element.
+period(cb::MyPeriodicCallback) = cb.Δt
+period(cbs::Vector{<:MyPeriodicCallback}) = map(period, cbs)
+affect_function(cb::MyPeriodicCallback) = cb.affect
+affect_functions(cbs::Vector{<:MyPeriodicCallback}) = map(affect_function, cbs)
+
for prop in [:eqs
:noiseeqs
:iv
@@ -513,12 +557,25 @@ function continuous_events(sys::AbstractSystem)
systems = get_systems(sys)
cbs = [obs;
reduce(vcat,
(map(o -> namespace_equation(o, s), continuous_events(s))
for s in systems),
init = SymbolicContinuousCallback[])]
filter(!isempty, cbs)
end
+function periodic_events(sys::AbstractSystem)
+    # This system's periodic events followed by those of every subsystem, with empty
+    # (zero-`Δt`) callbacks dropped.
+    # NOTE(review): unlike continuous events, subsystem callbacks are not namespaced, so
+    # their affects presumably see the parent system's state names — confirm intended.
+    obs = get_periodic_events(sys)
+    systems = get_systems(sys)
+    cbs = [obs;
+           reduce(vcat, (periodic_events(s) for s in systems),
+                  init = MyPeriodicCallback[])]
+    filter(!isempty, cbs)
+end
+
Base.@deprecate default_u0(x) defaults(x) false
Base.@deprecate default_p(x) defaults(x) false
function defaults(sys::AbstractSystem)
diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index f18f2e387..1d71ec778 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -245,6 +245,64 @@ function generate_rootfinding_callback(cbs, sys::ODESystem, dvs = states(sys),
end
end
+# Periodic callbacks for the periodic events of `sys`, or `nothing` if it has none.
+function generate_periodic_callbacks(sys::ODESystem, dvs = states(sys),
+                                     ps = parameters(sys); kwargs...)
+    cbs = periodic_events(sys)
+    return isempty(cbs) ? nothing :
+           generate_periodic_callbacks(cbs, sys, dvs, ps; kwargs...)
+end
+
+# Build one `PeriodicCallback` per entry of `cbs` (each firing its affect every `Δt`),
+# or return `nothing` when `cbs` is empty.
+function generate_periodic_callbacks(cbs, sys::ODESystem, dvs = states(sys),
+                                     ps = parameters(sys); kwargs...)
+    isempty(cbs) && return nothing
+    Δts = map(period, cbs)
+    # Each user affect is wrapped so it reads/writes the integrator by variable name.
+    affects = map(cbs) do cb
+        compile_affect(affect_function(cb), sys, dvs, ps; kwargs...)
+    end
+    return PeriodicCallback.(affects, Δts)
+end
+
+# A `MyPeriodicCallback` compiles to the integrator wrapper of its affect function.
+compile_affect(cb::MyPeriodicCallback, args...; kwargs...) =
+    compile_affect(affect_function(cb), args...; kwargs...)
+
+# Wrap a user `cb(u, p, t)` as an integrator affect. `u`/`p` are `Dict`s keyed by the
+# string names of `dvs`/`ps` (e.g. "x(t)"); all entries are written back afterwards.
+function compile_affect(cb::Function, sys, dvs, ps; kwargs...)
+    # Name-based lookup through `Dict`s: simple but slow (proof of concept).
+    vs = string.(tosymbol.(map(x -> time_varying_as_func(value(x), sys), dvs)))
+    pnames = string.(tosymbol.(map(x -> time_varying_as_func(value(x), sys), ps)))
+    rvdict = Dict(zip(vs, eachindex(vs)))
+    rpdict = Dict(zip(pnames, eachindex(pnames)))
+    # The closure must not assign to locals of this function: Julia would box them and
+    # share them across every call of the affect.
+    let vs = vs, pnames = pnames, rvdict = rvdict, rpdict = rpdict
+        function (integ)
+            udict = isnothing(integ.u) ? Dict{String, Any}() : Dict(zip(vs, integ.u))
+            # Skip `integ.p` entirely when the system has no parameters.
+            pdict = isempty(pnames) || isnothing(integ.p) ? Dict{String, Any}() :
+                    Dict(zip(pnames, integ.p))
+            cb(udict, pdict, integ.t)
+            for (name, v) in udict
+                integ.u[rvdict[name]] = v
+            end
+            for (name, v) in pdict
+                integ.p[rpdict[name]] = v # parameters use the parameter index map
+            end
+        end
+    end
+end
+
function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
compile_affect(affect_equations(cb), args...; kwargs...)
end
@@ -283,6 +341,15 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; kwargs...)
end
end
+# NOTE(review): this looks like a copy of an existing method. If
+# `generate_rootfinding_callback(sys::ODESystem, ...)` is already defined in this file,
+# delete this copy: redefining a method overwrites it (an error when precompiling on recent Julia).
+function generate_rootfinding_callback(sys::ODESystem, dvs = states(sys),
+                                       ps = parameters(sys); kwargs...)
+    cbs = continuous_events(sys)
+    return isempty(cbs) ? nothing :
+           generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
+end
function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
# if something is not x(t) (the current state)
# but is `x(t-1)` or something like that, pass in `x` as a callable function rather
@@ -749,8 +816,16 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem, u0map, tspan,
else
event_cb = nothing
end
+
+    # `nothing`, or a vector of `PeriodicCallback`s (one per periodic event).
+    periodic_event_cb = has_periodic_events(sys) ?
+                        generate_periodic_callbacks(sys; kwargs...) : nothing
+    # `reduce` cannot iterate `nothing`; an empty tuple makes the fold below a no-op.
+    periodic_cbs = something(periodic_event_cb, ())
+
difference_cb = has_difference ? generate_difference_cb(sys; kwargs...) : nothing
cb = merge_cb(event_cb, difference_cb)
+    cb = reduce(merge_cb, periodic_cbs; init = cb)
cb = merge_cb(cb, callback)
kwargs = filter_kwargs(kwargs)
diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 2a307cbe3..349b97e2d 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -101,6 +101,11 @@ struct ODESystem <: AbstractODESystem
The integrator will use root finding to guarantee that it steps at each zero crossing.
"""
continuous_events::Vector{SymbolicContinuousCallback}
+ """
+ periodic_events: A `Vector{MyPeriodicCallback}`; each entry runs its affect every
+ `Δt` of the independent variable (compiled to a `PeriodicCallback`).
+ """
+ periodic_events::Vector{MyPeriodicCallback}
"""
tearing_state: cache for intermediate tearing state
"""
@@ -110,9 +115,10 @@ struct ODESystem <: AbstractODESystem
"""
substitutions::Any
+    # Inner constructor: the positional `myevents` argument fills the `periodic_events` field.
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
- torn_matching, connector_type, connections, preface, events,
+ torn_matching, connector_type, connections, preface, events, myevents,
tearing_state = nothing, substitutions = nothing;
checks::Bool = true)
if checks
@@ -124,7 +130,7 @@ struct ODESystem <: AbstractODESystem
end
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
- connector_type, connections, preface, events, tearing_state, substitutions)
+ connector_type, connections, preface, events, myevents, tearing_state, substitutions) # `myevents` -> `periodic_events` field
end
end
@@ -139,6 +145,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
connector_type = nothing,
preface = nothing,
continuous_events = nothing,
+ periodic_events = nothing, # a MyPeriodicCallback, `Δt => affect` pair, `Δt`, or a vector of these
checks = true)
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@@ -172,9 +179,11 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
throw(ArgumentError("System names must be unique."))
end
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
+    # Normalize the user's `periodic_events` into a `Vector{MyPeriodicCallback}`.
+    periodic_cbs = MyPeriodicCallbacks(periodic_events)
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
- connector_type, nothing, preface, cont_callbacks, checks = checks)
+        connector_type, nothing, preface, cont_callbacks, periodic_cbs, checks = checks)
end
function ODESystem(eqs, iv = nothing; kwargs...)
@@ -244,6 +253,7 @@ function flatten(sys::ODESystem, noeqs = false)
parameters(sys),
observed = observed(sys),
continuous_events = continuous_events(sys),
+ periodic_events = periodic_events(sys), # includes the subsystems' periodic events
defaults = defaults(sys),
name = nameof(sys),
checks = false)
@@ -257,6 +267,11 @@ get_continuous_events(sys::AbstractODESystem) = getfield(sys, :continuous_events
has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events)
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
+# Only systems with a `periodic_events` field (in this patch: `ODESystem`) store any.
+get_periodic_events(sys::AbstractSystem) = MyPeriodicCallback[]
+get_periodic_events(sys::AbstractODESystem) =
+    has_periodic_events(sys) ? getfield(sys, :periodic_events) : MyPeriodicCallback[]
+has_periodic_events(sys::AbstractSystem) = isdefined(sys, :periodic_events)
"""
$(SIGNATURES)