Hello,
I have a question about callbacks in DifferentialEquations.jl, because I find the behaviour of `DiscreteCallback`s counterintuitive.
Consider the following code, which is an adaptation of this:
```julia
using DifferentialEquations
using JumpProcesses
using Plots
using Random
using StatsBase

# threshold and reset value
u_thr = 2.
u_base = 1.

# exponential growth between jumps
f(u, p, t) = 1.01 * u

# continuous callback: fires when u crosses u_thr during the ODE evolution
function continuous_condition(u, t, integrator)
    u - u_thr
end
function continuous_affect!(integrator)
    integrator.u = u_base
    push!(continuous_crosses, integrator.t)
end

# discrete callback: checked after every step, fires if u is above u_thr
function discrete_condition(u, t, integrator)
    u .> u_thr
end
function discrete_affect!(integrator)
    integrator.u = u_base
    push!(discrete_crosses, integrator.t)
end

#jump process
rate(u, p, t) = 1. #in Hz
function affect!(integrator)
    integrator.u += sample([0., 1.])  # jump of size 0 or 1 with equal probability
end

# indices and times of saved points at or above threshold
function find_greater_equal_thresh(sol)
    spikes = findall(sol.u .>= u_thr)
    spiketimes = sol.t[spikes]
    return spikes, spiketimes
end

begin
    continuous_crosses = Float64[]
    discrete_crosses = Float64[]
    Random.seed!(2)
    spike_cb = ContinuousCallback(continuous_condition, continuous_affect!; rootfind = SciMLBase.RightRootFind)
    discrete_condition_cb = DiscreteCallback(discrete_condition, discrete_affect!)
    cbs = CallbackSet(spike_cb, discrete_condition_cb)
    u0 = 3 / 2
    tspan = (0.0, 1.5)
    prob = ODEProblem(f, u0, tspan)
    jump = ConstantRateJump(rate, affect!)
    jump_prob = JumpProblem(prob, Direct(), jump)
    sol = solve(jump_prob, ImplicitEuler(), callback = cbs, save_everystep = false)
    spikes, spiketimes = find_greater_equal_thresh(sol)
    #plot solution
    p = plot(sol.t, sol.u, linewidth = 2, title = "Continuous vs discrete callbacks",
        xlabel = "Time (s)", alpha = 0.3)
    hline!([u_thr], label = "Threshold", linestyle = :dash, linewidth = 0.7)
    vline!(discrete_crosses, label = "Discrete callback", linestyle = :dash, linewidth = 0.7)
    vline!(continuous_crosses, label = "Continuous callback", linestyle = :dash, linewidth = 0.7)
    display(p)
    savefig(p, "08042024/outputs/callback-differences.pdf")
end

begin
    println("number of discrete callbacks: ", length(discrete_crosses))#1
    println("number of continuous callbacks: ", length(continuous_crosses))#2
    println("number of time frames spent at or above threshold: ", length(findall(sol.u .>= u_thr)))#3
    println("times when u is above threshold: ", sol.t[findall(sol.u .>= u_thr)])#4
    println("timesteps when u is above threshold: ", findall(sol.u .>= u_thr))#5
    println("times of discrete callbacks: ", discrete_crosses)#6
    println("times of continuous callbacks: ", continuous_crosses)#7
end
```
The code produces the following graph:
The graph suggests that there are 3 crossings of the threshold, so there should be 3 timesteps at which `u >= u_thr` holds. However, this is not the case: there are 5!
Here are the outputs of the print statements:
```
number of discrete callbacks: 2 #1
number of continuous callbacks: 1 #2
number of time frames spent at or above threshold: 5 #3
times when u is above threshold: [0.007120977006082045, 0.007120977006082045, 0.5984251769043971, 0.5984251769043971, 1.2748637298022953] #4
timesteps when u is above threshold: [3, 4, 7, 8, 10] #5
times of discrete callbacks: [0.007120977006082045, 0.5984251769043971] #6
times of continuous callbacks: [1.2748637298022953] #7
```
The print statements show that, although there are only three callback calls (three crossings), the dynamical variable is not reset immediately after a `DiscreteCallback` call. Every `DiscreteCallback` call comes with two saved points, sharing the same time stamp, at which `u` is above threshold (i.e. fulfils the condition for being reset). When the `ContinuousCallback` is called, on the other hand, `u` is immediately reset to `u_base`. Why does the dynamical variable stay above threshold after a `DiscreteCallback` instead of being reset at the next timestep?
On the one hand this feels like a fine detail, but it plays a large role when analysing the solution of an `ODEProblem` after the fact. If I want to find the time points where `u` crossed the threshold, I have to know whether the crossing triggered a `ContinuousCallback` or a `DiscreteCallback`, which I cannot tell from the solution once the code has run.
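One possible post-processing workaround is to collapse saved points that share a time stamp, so that each crossing is counted once regardless of which callback handled it. The sketch below is a minimal, hypothetical helper (`crossing_times` is my own name, not part of DifferentialEquations.jl). It assumes that repeated time stamps above threshold come only from a callback saving the state more than once at the same instant, as with the saved points 3/4 and 7/8 printed above; it would over-count if `u` stayed above threshold across several genuinely different time steps.

```julia
# Hypothetical helper, not part of DifferentialEquations.jl.
# `t` and `u` play the roles of `sol.t` and `sol.u` for a scalar ODE.
# Saved points at or above `thr` that share a time stamp are counted once.
function crossing_times(t::AbstractVector{<:Real}, u::AbstractVector{<:Real}, thr::Real)
    above = findall(u .>= thr)   # indices of saved points at or above threshold
    return unique(t[above])      # collapse duplicate time stamps, keeping first-seen order
end

# Toy data (made up for illustration): the time stamp 0.1 is repeated, as if a
# callback had saved the state several times at that instant.
t = [0.0, 0.1, 0.1, 0.1, 0.1, 0.5, 0.9, 0.9]
u = [1.5, 1.5, 2.5, 2.5, 1.0, 1.8, 2.0, 1.0]
println(crossing_times(t, u, 2.0))   # [0.1, 0.9]
```

With the real solution this would be called as `crossing_times(sol.t, sol.u, u_thr)`.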
Note: I need both a continuous and a discrete callback here because (I think) `u` has two types of dynamics: if `u` crosses the threshold through the normal evolution of the differential equation, a `ContinuousCallback` detects the crossing. If instead the crossing is caused by a jump from the jump process, I need a `DiscreteCallback`.