DiscreteCallback takes a break: discrete callback shows counterintuitive behaviour in DifferentialEquations.jl

Hello,
I have a question about the behaviour of callbacks using DifferentialEquations.jl, because I find the behaviour of DiscreteCallbacks counterintuitive.
Consider the following code, which is an adaptation of this:

using DifferentialEquations
using JumpProcesses
using Plots
using Random 
using StatsBase

u_thr = 2.
u_base = 1.

f(u, p, t) = 1.01 * u

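# ContinuousCallback: fires at a root of this function, i.e. when u crosses u_thr during a step
function continuous_condition(u, t, integrator)
    u - u_thr
end

function continuous_affect!(integrator)
    integrator.u = u_base
    push!(continuous_crosses, integrator.t)
end

# DiscreteCallback: this Bool is checked after every step (u is a scalar here, so no broadcasting is needed)
function discrete_condition(u, t, integrator)
    u > u_thr
end

function discrete_affect!(integrator)
    integrator.u = u_base
    push!(discrete_crosses, integrator.t)
end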

#jump process
rate(u, p, t) = 1. #in Hz
function affect!(integrator)
    integrator.u += sample([0., 1.])
end

function find_greater_equal_thresh(sol)
    spikes = findall(sol.u .>= u_thr)
    spiketimes = sol.t[spikes]
    return spikes, spiketimes
end


begin
    continuous_crosses = Float64[]
    discrete_crosses = Float64[]

    Random.seed!(2)

    # one ContinuousCallback for smooth crossings, one DiscreteCallback for jump-induced crossings
    spike_cb = ContinuousCallback(continuous_condition, continuous_affect!; rootfind = SciMLBase.RightRootFind)
    discrete_condition_cb = DiscreteCallback(discrete_condition, discrete_affect!)
    cbs = CallbackSet(spike_cb, discrete_condition_cb)
    u0 = 3 / 2
    tspan = (0.0, 1.5)
    prob = ODEProblem(f, u0, tspan)
    jump = ConstantRateJump(rate, affect!)
    jump_prob = JumpProblem(prob, Direct(), jump)
    sol = solve(jump_prob, ImplicitEuler(), callback = cbs, save_everystep = false);
    spikes, spiketimes= find_greater_equal_thresh(sol)
    
    #plot solution
    p = plot(sol.t, sol.u, linewidth = 2, title = "Continuous vs discrete callbacks",
        xlabel = "Time (s)", alpha = 0.3)
    hline!([u_thr], label = "Threshold", linestyle = :dash, linewidth = 0.7)
    vline!(discrete_crosses, label = "Discrete callback", linestyle = :dash, linewidth = 0.7)
    vline!(continuous_crosses, label = "Continuous callback", linestyle = :dash, linewidth = 0.7)
    display(p)
    savefig(p, "08042024/outputs/callback-differences.pdf")
end

begin
    println("number of discrete callbacks: ", length(discrete_crosses))#1
    println("number of continuous callbacks: ", length(continuous_crosses))#2
    println("number of time frames spent at or above threshold: ", length(findall(sol.u .>= u_thr)))#3

    println("times when u is above threshold: ", sol.t[findall(sol.u .>= u_thr)])#4
    println("timesteps when u is above threshold: ", findall(sol.u .>= u_thr))#5
    
    println("times of discrete callbacks: ", discrete_crosses)#6
    println("times of continuous callbacks: ", continuous_crosses)#7
end

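The code produces the following graph:
[Figure: callback-differences, u(t) with the threshold and the discrete- and continuous-callback times marked as dashed lines]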

The graph suggests that there are 3 crossings of the threshold, so there should be 3 time steps at which u >= u_thr holds. However, there are 5!
Here are the outputs of the print statements:

number of discrete callbacks: 2 #1

number of continuous callbacks: 1 #2

number of time frames spent at or above threshold: 5 #3

times when u is above threshold: [0.007120977006082045, 0.007120977006082045, 0.5984251769043971, 0.5984251769043971, 1.2748637298022953] #4

timesteps when u is above threshold: [3, 4, 7, 8, 10] #5

times of discrete callbacks: [0.007120977006082045, 0.5984251769043971] #6

times of continuous callbacks: [1.2748637298022953] #7

The print statements show that, although there are only three callback calls (three crossings), the dynamical variable is not reset immediately after a DiscreteCallback call: every DiscreteCallback call comes with two saved time steps at which u is above threshold (i.e. at which the reset condition holds).
By contrast, when the ContinuousCallback is called, u is immediately reset to u_base. Why does u stay above threshold after a DiscreteCallback instead of being reset at the next time step?

This may seem like a fine detail, but it plays a large role when analysing the solution of an ODEProblem after the fact. If I want to find the time points where u crossed the threshold, I have to know whether the crossing was handled by a ContinuousCallback or a DiscreteCallback, which I cannot tell after the code has run.
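A minimal sketch, using only Base Julia and the saved times printed in output #4 above: because several saves can share one time stamp, counting distinct time points rather than saved entries gives the three threshold events, regardless of which callback handled them. In the script above the same step would be `unique(spiketimes)` after `find_greater_equal_thresh(sol)`.

```julia
# Saved times at which u >= u_thr, copied from output #4 above.
spiketimes = [0.007120977006082045, 0.007120977006082045, 0.5984251769043971,
              0.5984251769043971, 1.2748637298022953]

# Several saves can share one time stamp, so deduplicate before counting events.
# `unique` keeps the first occurrence of each value, in order.
event_times = unique(spiketimes)
println("number of distinct event times: ", length(event_times))
```

This only counts events correctly as long as the saves of one event share an identical time stamp, which is the case in the output above.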

Note: I need both a continuous and a discrete callback here because there are two types of dynamics of u (I think): if u crosses the threshold through the normal evolution of the differential equation, a ContinuousCallback will detect the crossing. If instead the crossing is caused by a jump from the jump process, I need a DiscreteCallback.
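A toy model of the distinction drawn in this note, in plain Julia. It is not a description of how DifferentialEquations.jl works internally. Between jumps u grows as u(t) = u0·exp(1.01(t - t0)), so a smooth crossing has a well-defined crossing time that can be solved for (the role of root-finding in a ContinuousCallback). A jump moves u past the threshold discontinuously, so the natural thing to do is check the condition right after the jump (the role of a DiscreteCallback). The constant and function names below are illustrative assumptions.

```julia
# Toy model only: NOT DifferentialEquations.jl internals.
const U_THR = 2.0    # threshold (u_thr in the script above)
const U_BASE = 1.0   # reset value (u_base in the script above)
const GROWTH = 1.01  # growth rate in f(u, p, t) = 1.01 * u

# Smooth crossing: for u' = GROWTH * u starting at u0 < U_THR at time t0,
# u reaches U_THR exactly at this time (what a continuous root-find locates).
crossing_time(u0, t0) = t0 + log(U_THR / u0) / GROWTH

# Jump: u changes discontinuously, so there is no crossing inside a step to solve for;
# the condition is simply checked after the jump and u is reset if it holds.
function apply_jump(u, du)
    u_new = u + du
    return u_new > U_THR ? U_BASE : u_new
end

t_cross = crossing_time(1.0, 0.0)  # starting from the reset value, u hits U_THR at log(2)/1.01
u_after = apply_jump(1.5, 1.0)     # 1.5 + 1.0 = 2.5 > U_THR, so u is reset to 1.0
println("smooth crossing at t = ", t_cross)
println("u after jump: ", u_after)
```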


No, it doesn’t. You have a lot of things going on here, so it’s easiest to set save_positions = (false,false) everywhere and save_everystep = true, so that you see the time-step values directly.

spike_cb = ContinuousCallback(continuous_condition, continuous_affect!; rootfind = SciMLBase.RightRootFind, save_positions = (false,false));
discrete_condition_cb = DiscreteCallback(discrete_condition, discrete_affect!, save_positions = (false,false))
cbs = CallbackSet(spike_cb, discrete_condition_cb)
u0 = 3 / 2
tspan = (0.0, 1.5)
prob = ODEProblem(f, u0, tspan)
jump = ConstantRateJump(rate, affect!)
jump_prob = JumpProblem(prob, Direct(), jump, save_positions = (false,false))
sol = solve(jump_prob, ImplicitEuler(), callback = cbs, save_everystep = true);
spikes, spiketimes= find_greater_equal_thresh(sol)

When I do that I see:

julia> sol.t
88-element Vector{Float64}:
 0.0
 9.809495801065397e-6
 1.1771394961278477e-5
 3.139038656340927e-5
 ⋮
 1.4332244667557321
 1.4656928472826676
 1.4721865233880547
 1.5

julia> sol.u
88-element Vector{Float64}:
 1.5
 1.5000148615333806
 1.5000178338459464
 1.5000475575605854
 ⋮
 1.1760287101851352
 1.2159018580374032
 1.223929135412532
 1.2593050432952182
julia> println(sol.u)
[1.5, 1.5000148615333806, 1.5000178338459464, 1.5000475575605854, 1.5003448536167279, 1.50332371684972, 1.5084169578785362, 1.5108407473183476, 1.0000777344613467, 1.000093281595307, 1.0002487771077493, 1.0018061532799305, 1.0055486048710025, 1.0143653933771393, 1.025963923741493, 1.0503650425838422, 1.0690389306209556, 1.1160109430894622, 1.135168744096239, 1.1672789457486037, 1.1888497854348854, 1.2371989533444483, 1.2596209130012974, 1.298456736034086, 1.3233705826385884, 1.3751524208001553, 1.4013703170466076, 1.44722112928813, 1.4760920970047944, 1.5328656767765925, 1.5634742304786355, 1.6168709212042984, 1.6503696290158318, 1.7135112700935637, 1.7491871172333195, 1.8334811173852628, 1.0016721333743062, 1.003723486902015, 1.024665013300298, 1.0288705084883354, 1.0727177091489573, 1.0815594301263731, 1.1154275745500473, 1.1228953843894307, 1.1611865536840231, 1.1688956302876146, 1.2088922087132148, 1.2169448089600152, 1.2585276238502385, 1.2668995096629088, 1.3102123762719498, 1.3189326040733527, 1.3640138269190756, 1.3730900661828156, 1.4200258364471632, 1.429475457708902, 1.478336245115153, 1.4881734287417667, 1.5390406100480114, 1.5492817427103187, 1.6022366972875233, 1.6128981618568194, 1.6680270041511724, 1.6791261386205747, 1.7365179307685399, 1.7480726658299222, 1.807820366810815, 1.8198494180212514, 1.8820497143124644, 1.8945725475732536, 1.959326123957346, 1.9723630101883562, 2.0000000000000004, 1.0003929696982747, 1.0008667912744005, 1.005627544978201, 1.008030425761198, 1.0326460045834192, 1.037592703620255, 1.0782144448425102, 1.0864004748412424, 1.121581338694256, 1.1286619311337707, 1.1680895841818448, 1.1760287101851352, 1.2159018580374032, 1.223929135412532, 1.2593050432952182]

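So no step value is meaningfully above 2 (the single entry 2.0000000000000004 is the step ending at the ContinuousCallback's root, i.e. on the threshold up to floating-point round-off), so that statement is just wrong.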

Why is it wrong? Well, let’s turn save_positions back on for everything except the continuous callback, and set save_everystep = false. Then I see:

julia> println(sol.t)
[0.0, 0.007120977006082045, 0.007120977006082045, 0.007120977006082045, 0.007120977006082045, 0.5984251769043971, 0.5984251769043971, 0.5984251769043971, 0.5984251769043971, 1.5]

julia> println(sol.u)
[1.5, 1.5108407473183476, 2.5108407473183476, 2.5108407473183476, 1.0, 1.8334811173852628, 2.8334811173852628, 2.8334811173852628, 1.0, 1.2593050432952182]

And finally, save_positions only on the discrete callback:

spike_cb = ContinuousCallback(continuous_condition, continuous_affect!; rootfind = SciMLBase.RightRootFind, save_positions = (false,false));
discrete_condition_cb = DiscreteCallback(discrete_condition, discrete_affect!, save_positions = (true,true))
cbs = CallbackSet(spike_cb, discrete_condition_cb)
u0 = 3 / 2
tspan = (0.0, 1.5)
prob = ODEProblem(f, u0, tspan)
jump = ConstantRateJump(rate, affect!)
jump_prob = JumpProblem(prob, Direct(), jump, save_positions = (false,false))
sol = solve(jump_prob, ImplicitEuler(), callback = cbs, save_everystep = false);
spikes, spiketimes= find_greater_equal_thresh(sol)
julia> sol.t
6-element Vector{Float64}:
 0.0
 0.007120977006082045
 0.007120977006082045
 0.5984251769043971
 0.5984251769043971
 1.5

julia> println(sol.u)
[1.5, 2.5108407473183476, 1.0, 2.8334811173852628, 1.0, 1.2593050432952182]
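With only the DiscreteCallback saving, each event contributes exactly two entries with the same time stamp: the value just before `discrete_affect!` runs and the value just after it. A small Base-Julia sketch (the helper name `event_pairs` is hypothetical) that pulls those pairs out of the six saved values printed above:

```julia
# Hypothetical helper: for each pair of consecutive saves sharing a time stamp,
# return (time, value before the event, value after the event).
# Assumes exactly two saves per event, as in this save_positions configuration.
function event_pairs(ts::AbstractVector{<:Real}, us::AbstractVector{<:Real})
    pairs = Tuple{Float64,Float64,Float64}[]
    for i in 1:length(ts)-1
        if ts[i] == ts[i+1]
            push!(pairs, (ts[i], us[i], us[i+1]))
        end
    end
    return pairs
end

# sol.t and sol.u as printed above (save_positions only on the DiscreteCallback).
ts = [0.0, 0.007120977006082045, 0.007120977006082045,
      0.5984251769043971, 0.5984251769043971, 1.5]
us = [1.5, 2.5108407473183476, 1.0, 2.8334811173852628, 1.0, 1.2593050432952182]

for (t, before, after) in event_pairs(ts, us)
    println("t = ", t, ": ", before, " -> ", after)
end
```

With more saves per time stamp (for example when the jump also saves), consecutive duplicates would produce more than one pair per event, so the assumption above matters.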

And there we see what’s going on. In the time series:

julia> println(sol.t)
[0.0, 0.007120977006082045, 0.007120977006082045, 0.007120977006082045, 0.007120977006082045, 0.5984251769043971, 0.5984251769043971, 0.5984251769043971, 0.5984251769043971, 1.5]

julia> println(sol.u)
[1.5, 1.5108407473183476, 2.5108407473183476, 2.5108407473183476, 1.0, 1.8334811173852628, 2.8334811173852628, 2.8334811173852628, 1.0, 1.2593050432952182]

At 0.007120977006082045 you have a jump, which with save_positions = (true,true) saves just before the jump (1.5108407473183476) and just after it (2.5108407473183476). Then the DiscreteCallback is resolved, and with save_positions = (true,true) it saves just before its affect (2.5108407473183476) and just after it (1.0). Notice that all four saves happen at the same time, 0.007120977006082045; the same pattern repeats at 0.5984251769043971.

So it is doing exactly what you specified in the model: the dynamical variable does not stay above the threshold until the next time step. It is reset at the same time point, and the extra above-threshold entries are just the saves around the jump and the callback.
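If, for post-processing, you only want the state after everything that happened at a given time point, one option (a sketch, not a DifferentialEquations.jl feature; `last_per_time` is a hypothetical helper) is to keep the last saved value for each repeated time stamp. Applied to the ten-entry series above, no remaining value is at or above the threshold of 2:

```julia
# Hypothetical helper: keep only the last saved value at each time stamp.
# Assumes ts is sorted, as sol.t is, so saves at the same time are adjacent.
function last_per_time(ts::AbstractVector{<:Real}, us::AbstractVector{<:Real})
    keep = [i == length(ts) || ts[i] != ts[i+1] for i in eachindex(ts)]
    return ts[keep], us[keep]
end

# sol.t and sol.u as printed above (saves on the jump and the DiscreteCallback).
ts = [0.0, 0.007120977006082045, 0.007120977006082045, 0.007120977006082045,
      0.007120977006082045, 0.5984251769043971, 0.5984251769043971,
      0.5984251769043971, 0.5984251769043971, 1.5]
us = [1.5, 1.5108407473183476, 2.5108407473183476, 2.5108407473183476, 1.0,
      1.8334811173852628, 2.8334811173852628, 2.8334811173852628, 1.0,
      1.2593050432952182]

t_last, u_last = last_per_time(ts, us)
println(t_last)  # [0.0, 0.007120977006082045, 0.5984251769043971, 1.5]
println(u_last)  # [1.5, 1.0, 1.0, 1.2593050432952182]
```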

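As for telling afterwards whether a crossing was handled by a ContinuousCallback or a DiscreteCallback: you can, if you set up a save array that holds this information.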
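A minimal sketch of such a save array, assuming you are free to change the affect functions: log each event together with the callback that handled it, in one shared vector instead of (or in addition to) the two separate ones. The `event_log` name and the `:continuous`/`:discrete` tags are illustrative choices; here the log is filled from the times printed in outputs #6 and #7 of the question rather than from a live solve.

```julia
# One shared log of (time, source) entries. In the original script each affect! function
# would push into it, e.g. push!(event_log, (t = integrator.t, source = :discrete)).
event_log = NamedTuple{(:t, :source), Tuple{Float64, Symbol}}[]

# Filled here from outputs #6 and #7 instead of a live solve.
for t in [0.007120977006082045, 0.5984251769043971]
    push!(event_log, (t = t, source = :discrete))
end
push!(event_log, (t = 1.2748637298022953, source = :continuous))

# Sort by time so the log reads as a single event history.
sort!(event_log, by = e -> e.t)
for e in event_log
    println(e.t, "  ", e.source)
end
```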


Thanks for the quick response, indeed I was wrong.
Amazing support, thank you!