Why are only ContinuousCallbacks exposed in ODESystem?

DifferentialEquations.jl supports both discrete and continuous callbacks, but it seems only the latter kind is available when using ModelingToolkit:

function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
                   ...
                   continuous_events = nothing,
                   checks = true)

This makes it difficult to simulate, e.g., interventions at specified times, as shown in the "Event Handling and Callback Functions" page of the DifferentialEquations.jl docs.

Is there a reason for this omission?

Thanks,
DD

Because ModelingToolkit, as mentioned in the JuliaCon workshop, is still a work in progress. Someone just needs to expose the discrete callbacks.

Bummer.
It seems there's no way to add a non-trivial control component.

You can grow the ecosystem by contributing to it. 🙂


I've actually played with this idea… but I'm not even sure what the correct interface to support these callbacks is (can/should they be represented as equations?).

Of course, this is just an excuse: ModelingToolkit is way too complex to just pop in and contribute a feature.

You'd be surprised. It's almost just a copy-paste of the continuous callback code. Did you check the tests to see whether it's just not documented yet? I know @baggepinnen and @sharanry were doing a few adjacent things.

I think events are currently limited to continuous events with equality conditions, and they only work for ODESystems. So there are several directions in which they could be generalized. A good first PR would probably be adding continuous events, as they currently exist, to some other system types (like SDESystem). That would basically mean copying the current ODESystem event code, and it would help one understand how it works.
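To make "continuous events with equality conditions" concrete: such an event fires when a condition like `x ~ 0` changes sign during a step, and the solver then locates the crossing time. Below is a rough, dependency-free sketch of that idea using plain bisection on a known function of time. Real solvers work on the interpolated solution and use better root finders, so treat `locate_event` as a hypothetical illustration, not ModelingToolkit or DiffEq code.

```julia
# Hypothetical sketch: locate a zero crossing of the event condition g
# inside one step [t0, t1] by bisection. Returns `nothing` when g has the
# same sign at both ends (no event detected in this step).
function locate_event(g, t0, t1; tol = 1e-12)
    g0, g1 = g(t0), g(t1)
    sign(g0) == sign(g1) && return nothing   # no crossing in this step
    lo, hi = t0, t1
    while hi - lo > tol
        mid = (lo + hi) / 2
        # keep the half-interval whose endpoints still straddle the root
        if sign(g(mid)) == sign(g0)
            lo = mid
        else
            hi = mid
        end
    end
    return (lo + hi) / 2
end

# Usage: falling ball x(t) = 1 - 4.9 t^2 with the condition x ~ 0.
t_hit = locate_event(t -> 1 - 4.9 * t^2, 0.4, 0.5)   # ≈ sqrt(1 / 4.9)
```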

What do you mean?

Looking at the code, there's no mention of other types of callbacks.
Periodic callbacks are used, but only internally, to integrate with DiscreteSystems.

However, I think the problem is more to do with how affects are specified (regardless of discrete vs. continuous): they are currently implemented as “equations” (which are really assignment expressions). This makes sense in that it simplifies access to the state of the system, but it also severely limits what can be done on a callback.
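Here is a small sketch of what "equations that are really assignments" means. It assumes that all right-hand sides are read from the pre-event state, so an affect like `v ~ -v` becomes `v := -v`. `apply_affect_equations!` is a made-up helper, not ModelingToolkit's generated code. Note that each right-hand side is a pure function of the state snapshot, which is exactly the limitation described above: there is nowhere to keep, e.g., controller state between events.

```julia
# Hypothetical sketch: each affect is `target_index => rhs`, where rhs is a
# pure function of the state. All right-hand sides are evaluated on a
# snapshot taken before anything is written (an assumption made here), so
# the order in which the affects are listed does not matter.
function apply_affect_equations!(u::Vector{Float64}, affects)
    old = copy(u)                                        # pre-event snapshot
    new_vals = [(i, rhs(old)) for (i, rhs) in affects]   # evaluate every RHS first
    for (i, val) in new_vals
        u[i] = val                                       # then assign
    end
    return u
end

# Usage: state u = [x, v]; the affect "v ~ -v" reverses the velocity.
apply_affect_equations!([0.0, -3.0], [2 => (s -> -s[2])])   # returns [0.0, 3.0]
```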

What's needed is a way to run arbitrary Julia code as an affect (yes, you could @register_symbolic a function, but it has to be a pure function, so it can't keep state). The interface for such an affect function should give access to the ODESystem's state, i.e. its variables and parameters, and should allow modifying them. DiffEq exposes the integrator, but that would be too low-level: maybe pass keyword arguments instead of a raw u vector?
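One hypothetical shape for such an interface, sketched without any package: the user's affect receives states and parameters by name (as NamedTuples, the closest thing to the "kwargs" idea) and returns only the values it wants to change. A wrapper then maps the names back to indices. `MockIntegrator` and `named_affect` are invented for this example. A real version would receive a DiffEq integrator and take the names from the ODESystem.

```julia
# Hypothetical stand-in for a DiffEq integrator: only u, p and t are mimicked.
mutable struct MockIntegrator
    u::Vector{Float64}
    p::Vector{Float64}
    t::Float64
end

# Wrap a user affect `(states, params, t) -> NamedTuple of changes` into an
# integrator-facing `affect!(integ)`.
function named_affect(user_affect, state_names::Vector{Symbol}, param_names::Vector{Symbol})
    sidx = Dict(n => i for (i, n) in enumerate(state_names))
    pidx = Dict(n => i for (i, n) in enumerate(param_names))
    return function (integ)
        states = NamedTuple{Tuple(state_names)}(Tuple(integ.u))
        params = NamedTuple{Tuple(param_names)}(Tuple(integ.p))
        changes = user_affect(states, params, integ.t)
        for (name, val) in pairs(changes)
            if haskey(sidx, name)
                integ.u[sidx[name]] = val
            else
                integ.p[pidx[name]] = val   # KeyError for unknown names
            end
        end
        return nothing
    end
end

# Usage: a bounce that reverses and damps v using a restitution parameter e.
bounce! = named_affect((s, p, t) -> (v = -p.e * s.v,), [:x, :v], [:e])
integ = MockIntegrator([0.0, -2.0], [0.5], 1.0)
bounce!(integ)    # integ.u is now [0.0, 1.0]
```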

DD

Come to think of it, the problem is less about continuous vs. discrete events and more about how affects are specified.
After all, I could easily implement a periodic event with continuous callbacks (e.g. with sin), but the kinds of manipulations I can do when they trigger are pretty limited. It would be nice if we could run arbitrary Julia code as an affect.
That said, I'm not sure about the interface for such a function, which I guess is one reason why equation-like affects were chosen.
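The sin trick can be sketched as follows, assuming a period `Δt`. The condition `sin(π t / Δt)` is zero at every multiple of `Δt`, so a root-finding callback on it triggers at every zero crossing, i.e. once per period. `periodic_condition` and `count_crossings` are hypothetical helpers. The `(u, t, integrator)` argument order mirrors the condition convention of DiffEq's `ContinuousCallback`, and the grid scan is only a stand-in for the solver's event detection.

```julia
# Hypothetical helper: a continuous condition whose zeros are the multiples of Δt.
periodic_condition(Δt) = (u, t, integrator) -> sin(π * t / Δt)

# Count sign changes of the condition between consecutive grid points
# (a crude stand-in for what the solver's root finder detects step by step).
function count_crossings(cond, ts)
    vals = [cond(nothing, t, nothing) for t in ts]
    count(i -> sign(vals[i]) != sign(vals[i + 1]), 1:length(vals) - 1)
end

# Usage: with Δt = 1 the condition crosses zero at t = 1, 2, 3 on (0.05, 3.95).
count_crossings(periodic_condition(1.0), 0.05:0.1:3.95)   # 3
```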

DD

OK,
I’ve implemented a proof-of-concept supporting periodic callbacks. Patch below.

It allows me to write code such as:

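using ModelingToolkit  # plus the patch below, which adds `periodic_events`

@variables t x(t)=1 v(t)=0
D = Differential(t)

# Keys are the printed variable names, e.g. "x(t)"
function affect(u, p, t)
    u["x(t)"] = -u["x(t)"]
end

period = 1.0

@named ball = ODESystem([D(x) ~ v
                         D(v) ~ -9.8], t, periodic_events = period => affect)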

This is not a good interface…
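Here is a standalone sketch of what the patch's `compile_affect(cb::Function, ...)` wrapper does, so it can be tried without ModelingToolkit. It copies `u` and `p` into `Dict`s keyed by variable names such as `"x(t)"`, calls the user's affect, and writes every entry back by index. `MockIntegrator` and `dict_affect` are hypothetical stand-ins (only the `u`, `p` and `t` fields of a real integrator are mimicked). Parameters are written back through their own index map.

```julia
# Hypothetical stand-in for a DiffEq integrator: only u, p and t are mimicked.
mutable struct MockIntegrator
    u::Vector{Float64}
    p::Vector{Float64}
    t::Float64
end

function dict_affect(user_affect, state_names::Vector{String}, param_names::Vector{String})
    sidx = Dict(n => i for (i, n) in enumerate(state_names))
    pidx = Dict(n => i for (i, n) in enumerate(param_names))
    return function (integ)
        u = Dict(zip(state_names, integ.u))   # name => current value
        p = Dict(zip(param_names, integ.p))
        user_affect(u, p, integ.t)            # user mutates the Dicts
        for (name, val) in u
            integ.u[sidx[name]] = val         # KeyError if a new key was added
        end
        for (name, val) in p
            integ.p[pidx[name]] = val         # parameter index map, not the state map
        end
        return nothing
    end
end

# Usage with the affect from the post (flip x, leave v alone):
flip_x(u, p, t) = (u["x(t)"] = -u["x(t)"])
integ = MockIntegrator([1.0, 0.0], Float64[], 1.0)
dict_affect(flip_x, ["x(t)", "v(t)"], String[])(integ)   # integ.u is now [-1.0, 0.0]
```

Building two `Dict`s on every call allocates, and a misspelled key only fails at run time with a `KeyError`, which is part of why this interface feels awkward.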

Patch

--- a/src/ModelingToolkit.jl
+++ b/src/ModelingToolkit.jl
@@ -211,4 +211,6 @@ export modelingtoolkitize
 export @variables, @parameters
 export @named, @nonamespace, @namespace, extend, compose
 
+export MyPeriodicCallback, MyPeriodicCallbacks, MY_NULL_AFFECT, periodic_events
+
 end # module
diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl
index e2cba876f..b05d53d33 100644
--- a/src/systems/abstractsystem.jl
+++ b/src/systems/abstractsystem.jl
@@ -217,6 +217,50 @@ namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallbac
                                                                                                                namespace_equation.(affect_equations(cb),
                                                                                                                                    (s,)))
 
+const MY_NULL_AFFECT = (args...; kwargs...) -> nothing
+
+struct MyPeriodicCallback
+    Δt::Number
+    affect::Function
+    function MyPeriodicCallback(Δt::Number, affect = MY_NULL_AFFECT)
+        new(Δt, affect)
+    end # Default affect to nothing
+end
+
+function Base.:(==)(e1::MyPeriodicCallback, e2::MyPeriodicCallback)
+    isequal(e1.Δt, e2.Δt) && isequal(e1.affect, e2.affect)
+end
+Base.isempty(cb::MyPeriodicCallback) = iszero(cb.Δt)
+function Base.hash(cb::MyPeriodicCallback, s::UInt)
+    s = hash(cb.Δt, s)
+    hash(cb.affect, s)
+end
+
+MyPeriodicCallback(p::Pair) = MyPeriodicCallback(p[1], p[2])
+MyPeriodicCallback(cb::MyPeriodicCallback) = cb # passthrough
+
+MyPeriodicCallbacks(cb::MyPeriodicCallback) = [cb]
+MyPeriodicCallbacks(cbs::Vector{<:MyPeriodicCallback}) = cbs
+MyPeriodicCallbacks(cbs::Vector) = MyPeriodicCallback.(cbs)
+function MyPeriodicCallbacks(Δt::Number)
+    MyPeriodicCallbacks(MyPeriodicCallback(Δt))
+end
+function MyPeriodicCallbacks(others)
+    MyPeriodicCallbacks(MyPeriodicCallback(others))
+end
+MyPeriodicCallbacks(::Nothing) = MyPeriodicCallbacks(0)
+
+period(cb::MyPeriodicCallback) = cb.Δt
+function period(cbs::Vector{<:MyPeriodicCallback})
+    reduce(vcat, [period(cb) for cb in cbs])
+end
+affect_function(cb::MyPeriodicCallback) = cb.affect
+
+function affect_functions(cbs::Vector{MyPeriodicCallback})
+    reduce(vcat, [affect_function(cb) for cb in cbs])
+end
+
+
 for prop in [:eqs
              :noiseeqs
              :iv
@@ -513,12 +557,25 @@ function continuous_events(sys::AbstractSystem)
     systems = get_systems(sys)
     cbs = [obs;
            reduce(vcat,
                   (map(o -> namespace_equation(o, s), continuous_events(s))
                    for s in systems),
                   init = SymbolicContinuousCallback[])]
     filter(!isempty, cbs)
 end
 
+function periodic_events(sys::AbstractSystem)
+    obs = get_periodic_events(sys)
+    filter(!isempty, obs)
+    systems = get_systems(sys)
+    cbs = [obs;
+           reduce(vcat,
+                  (periodic_events(s)
+                   for s in systems),
+                  init = MyPeriodicCallback[])]
+    filter(!isempty, cbs)
+end
+
+
 Base.@deprecate default_u0(x) defaults(x) false
 Base.@deprecate default_p(x) defaults(x) false
 function defaults(sys::AbstractSystem)
diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl
index f18f2e387..1d71ec778 100644
--- a/src/systems/diffeqs/abstractodesystem.jl
+++ b/src/systems/diffeqs/abstractodesystem.jl
@@ -245,6 +245,64 @@ function generate_rootfinding_callback(cbs, sys::ODESystem, dvs = states(sys),
     end
 end
 
+function generate_periodic_callbacks(sys::ODESystem, dvs = states(sys),
+    ps = parameters(sys); kwargs...)
+    cbs = periodic_events(sys)
+    isempty(cbs) && return nothing
+    generate_periodic_callbacks(cbs, sys, dvs, ps; kwargs...)
+end
+
+function generate_periodic_callbacks(cbs, sys::ODESystem, dvs = states(sys),
+                                                ps = parameters(sys); kwargs...)
+    Δts = map(cb -> cb.Δt, cbs)
+    num_periods = length.(Δts)
+    (isempty(Δts) || sum(num_periods) == 0) && return nothing
+
+    affect_functions = map(cbs) do cb
+        af_f = affect_function(cb)
+        affect = compile_affect(af_f, sys, dvs, ps; kwargs...)
+    end
+
+    PeriodicCallback.(affect_functions, Δts)
+end
+
+function compile_affect(cb::MyPeriodicCallback, args...; kwargs...)
+    compile_affect(affect_function(cb), args...; kwargs...)
+end
+
+function compile_affect(cb::Function, sys, dvs, ps; kwargs...)
+    u = map(x -> time_varying_as_func(value(x), sys), dvs)
+    p = map(x -> time_varying_as_func(value(x), sys), ps)
+
+    # incredibly stupid implementation
+
+    vs = string.(tosymbol.(u))
+    ps = string.(tosymbol.(p))
+
+    rvdict = Dict(zip(vs, 1:length(vs)))
+    rpdict = Dict(zip(ps, 1:length(ps)))
+    let vs = vs, ps = ps, rvdict=rvdict, rpdict=rpdict
+        function (integ)
+            u = isnothing(integ.u) ? Dict() : Dict(zip(vs, integ.u))
+            p = isnothing(integ.p) ? Dict() : Dict(zip(ps, integ.p))
+            cb(u, p, integ.t)
+
+            for (ũ,v) in u
+                integ.u[rvdict[ũ]] = v
+            end
+            for (p̃,v) in p
+                integ.p[rpdict[p̃]] = v
+            end
+        end
+    end
+end
+
 function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
     compile_affect(affect_equations(cb), args...; kwargs...)
 end
@@ -283,6 +341,15 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; kwargs...)
     end
 end
 
+
+function generate_rootfinding_callback(sys::ODESystem, dvs = states(sys),
+                                       ps = parameters(sys); kwargs...)
+    cbs = continuous_events(sys)
+    isempty(cbs) && return nothing
+    generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
+end
+
+
 function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
     # if something is not x(t) (the current state)
     # but is `x(t-1)` or something like that, pass in `x` as a callable function rather
@@ -749,8 +816,16 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem, u0map, tspan,
     else
         event_cb = nothing
     end
+
+    if has_periodic_events(sys)
+        periodic_event_cb = generate_periodic_callbacks(sys; kwargs...)
+    else
+        periodic_event_cb = nothing
+    end
+
     difference_cb = has_difference ? generate_difference_cb(sys; kwargs...) : nothing
     cb = merge_cb(event_cb, difference_cb)
+    cb = reduce(merge_cb, something(periodic_event_cb, ()); init = cb)
     cb = merge_cb(cb, callback)
 
     kwargs = filter_kwargs(kwargs)
diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl
index 2a307cbe3..349b97e2d 100644
--- a/src/systems/diffeqs/odesystem.jl
+++ b/src/systems/diffeqs/odesystem.jl
@@ -101,6 +101,11 @@ struct ODESystem <: AbstractODESystem
     The integrator will use root finding to guarantee that it steps at each zero crossing.
     """
     continuous_events::Vector{SymbolicContinuousCallback}
+    # Added
+    """
+    periodic_events: A `Vector{MyPeriodicCallback}` that model periodic callbacks.
+    """
+    periodic_events::Vector{MyPeriodicCallback}
     """
     tearing_state: cache for intermediate tearing state
     """
@@ -110,9 +115,10 @@ struct ODESystem <: AbstractODESystem
     """
     substitutions::Any
 
+
     function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
                        jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
-                       torn_matching, connector_type, connections, preface, events,
+                       torn_matching, connector_type, connections, preface, events, myevents,
                        tearing_state = nothing, substitutions = nothing;
                        checks::Bool = true)
         if checks
@@ -124,7 +130,7 @@ struct ODESystem <: AbstractODESystem
         end
         new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
             ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
-            connector_type, connections, preface, events, tearing_state, substitutions)
+            connector_type, connections, preface, events, myevents, tearing_state, substitutions)
     end
 end
 
@@ -139,6 +145,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
                    connector_type = nothing,
                    preface = nothing,
                    continuous_events = nothing,
+                   periodic_events = nothing,
                    checks = true)
     name === nothing &&
         throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@@ -172,9 +179,11 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
         throw(ArgumentError("System names must be unique."))
     end
     cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
+
+    mycallbacks = MyPeriodicCallbacks(periodic_events)
     ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
               ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
-              connector_type, nothing, preface, cont_callbacks, checks = checks)
+              connector_type, nothing, preface, cont_callbacks, mycallbacks, checks = checks)
 end
 
 function ODESystem(eqs, iv = nothing; kwargs...)
@@ -244,6 +253,7 @@ function flatten(sys::ODESystem, noeqs = false)
                          parameters(sys),
                          observed = observed(sys),
                          continuous_events = continuous_events(sys),
+                         periodic_events = periodic_events(sys),
                          defaults = defaults(sys),
                          name = nameof(sys),
                          checks = false)
@@ -257,6 +267,11 @@ get_continuous_events(sys::AbstractODESystem) = getfield(sys, :continuous_events
 has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events)
 get_callback(prob::ODEProblem) = prob.kwargs[:callback]
 
+get_periodic_events(sys::AbstractSystem) = MyPeriodicCallback[]
+get_periodic_events(sys::AbstractODESystem) = getfield(sys, :periodic_events)
+has_periodic_events(sys::AbstractSystem) = isdefined(sys, :periodic_events)
+
+
 """
 $(SIGNATURES)

Oh, if you want periodic events, then you can mix discrete-time update equation operators with continuous-time ones. There's an example in the tests:

It's very experimental right now, so it's purposefully not documented yet. But all it needs to be "finished" is more testing and polish, so if you're willing to hammer on it and open an issue when something doesn't work, feel free.
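For intuition about what mixing discrete-time updates with continuous-time equations means, here is a dependency-free sketch. It is not ModelingToolkit's implementation and not the example from the tests. The state `x` follows `dx/dt = -x + uc` and is integrated with explicit Euler. The control value `uc` is a discrete variable that is recomputed every `Δt` and held constant in between. `hybrid_sim` and its defaults are made up for illustration.

```julia
# Hypothetical sketch of a sampled-data (hybrid) simulation:
# continuous dynamics dx/dt = -x + uc, discrete proportional update of uc every Δt.
function hybrid_sim(; x0 = 0.0, setpoint = 1.0, gain = 2.0, Δt = 0.5, dt = 0.01, T = 5.0)
    x, uc = x0, 0.0
    steps_per_sample = round(Int, Δt / dt)   # Euler steps between discrete updates
    nsteps = round(Int, T / dt)
    for k in 0:nsteps - 1
        if k % steps_per_sample == 0
            uc = gain * (setpoint - x)        # discrete-time update at t = k * dt
        end
        x += dt * (-x + uc)                   # continuous-time dynamics, uc held
    end
    return x
end
```

With the default gain of 2, the sampled loop settles at `gain / (1 + gain) = 2/3`, the same fixed point as the continuous loop, because `x` stops changing exactly when `x == uc`.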

Thanks, but as I’ve said above, discrete vs. continuous is not the real issue: the real issue is whether one can have arbitrary affect functions.

I.e. something like:

@named de = ODESystem(eqs, t, [x,y], [a,b,c,d], ..., (cond, affect!, state), ...)

where cond and affect! are functions and state is an (optional) state object passed to these functions, e.g.:

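# Controller, update! and should_switch stand for user-defined code;
# p is assumed to support named access (p.x, p.y)
struct MyState
    c::Controller
end

function affect!(u, p, t, state)
    update!(state.c, u, p, t)
    if should_switch(state.c)
        p.x, p.y = p.y, p.x
    end
end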
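Here's a runnable version of that proposal, under some loud assumptions. `Controller`, `update!` and `should_switch` are placeholders; this controller just counts calls and asks for a switch every third one. Parameters live in a mutable struct so that `p.x, p.y = p.y, p.x` is legal. The extra `state` argument is bound with a closure, so the solver-facing affect keeps a `(u, p, t)` shape. None of this is an existing ModelingToolkit API.

```julia
# Placeholder controller: counts how often it has been consulted.
mutable struct Controller
    calls::Int
end

struct MyState
    c::Controller
end

# Hypothetical parameter container with named, mutable fields.
mutable struct Params
    x::Float64
    y::Float64
end

update!(c::Controller, u, p, t) = (c.calls += 1; nothing)
should_switch(c::Controller) = c.calls % 3 == 0   # switch on every third call

function affect!(u, p, t, state)
    update!(state.c, u, p, t)
    if should_switch(state.c)
        p.x, p.y = p.y, p.x                        # swap the two parameters
    end
    return nothing
end

# Closing over `state` gives a 3-argument affect whose state persists between calls.
make_affect(state) = (u, p, t) -> affect!(u, p, t, state)

# Usage
state = MyState(Controller(0))
p = Params(1.0, 2.0)
aff! = make_affect(state)
for t in 1.0:3.0
    aff!(nothing, p, t)
end
# After three calls the controller has requested one switch: p.x == 2.0, p.y == 1.0
```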