Discrete callback + Saving callback for changing parameters in DifferentialEquations doesn't work as intended

Hi,

I have been using DifferentialEquations.jl to solve a set of around 250 coupled SDEs that used to be solved with a custom Euler-Maruyama scheme. In the model being simulated there is a parameter, an adjacency matrix, that changes constantly, and every change to it is recorded.

I have been trying to use a SavingCallback together with a DiscreteCallback in a CallbackSet to reproduce a parameter change that used to be implemented as a function call after every step of the Euler-Maruyama integration procedure. Even though the DiscreteCallback is set up to run after every integration step, it appears to act even before the first step, messing up the parameter it changes.

Here’s an example of what I mean. The animation compares the DiffEq implementation with the legacy implementation. The parameter of interest determines the color of the dots. The initial conditions are exactly the same.

[Animation: comparison of simulation methods. Right is ground truth.]

The expected behaviour is for all points in both simulations to coincide exactly at the beginning, and for the color switches to match as time advances. Another clue that something is wrong: points of a given color should always be attracted to their “corner”, as on the right, but in the DiffEq version red dots sometimes move to the “blue” corner, and so on.

The (considerably) simplified code of both solutions looks like this:

The drift function and the parameter changing function:

```julia
function point_drift(X, B)
  return sum(B .* X; dims = 2)
end

# For illustrative purposes
function param_switch(X, B, dt)
  avg = sum(B; dims = 2) .* dt
  return B ./ avg
end
```

The legacy Euler-Maruyama solver, taking fixed-size steps:

```julia
function simulate!(X, B; Nt=200, dt=0.01, sigma=0.0)
  sol = zeros(...)
  params = zeros(...)
  t_points = 1:(Nt - 1)
  for i in t_points
    force_x = point_drift(X, B)
    # Euler-Maruyama update of the state, then record it
    X .= X .+ dt .* force_x .+ sigma * sqrt(dt) .* randn(size(X))
    sol[i] .= X

    # Act on the parameters
    new_B = param_switch(X, B, dt)
    params[i] = new_B
    B .= new_B
  end
end
```

The DiffEq implementation uses the same functions:

```julia
function drift(du, u, p, t)
  # Indexing made simpler for illustrative purposes
  du[:] .= point_drift(u, p.B)

  return nothing
end

function noise(du, u, p, t)
  # Indexing made simpler for illustrative purposes
  du[:] .= p.σ

  return nothing
end

function parameter_switch_affect!(integrator)
  # get_proposed_dt is the dt proposed for the *next* step; the step just
  # taken is integrator.t - integrator.tprev (identical for fixed-step EM)
  dt = get_proposed_dt(integrator)
  u = integrator.u
  p = integrator.p

  new_B = param_switch(u, p.B, dt)
  integrator.p.B = new_B
  return nothing
end

# Callback functions
true_condition = function (u, t, integrator)
    return true
end

param_switching_callback = DiscreteCallback(true_condition,
                                              parameter_switch_affect!;
                                              save_positions=(true, false))

save_B = function (u, t, integrator)
    return integrator.p.B
end
```

The callbacks are combined while preparing the problem:

```julia
B_cache = SavedValues(Float64, BitMatrix)
saving_callback = SavingCallback(save_B, B_cache; saveat = 0.01, save_end=false, save_start=true)
cbs = CallbackSet(param_switching_callback, saving_callback)

prob = SDEProblem(drift, noise, u0, (B = BitMatrix(250, 250)))
sol = solve(prob, SRIW1(); callback = cbs, alg_hints=:additive, save_everystep = true)
```

Without the SavingCallback I have no way of comparing the solutions, so I can’t use the DiscreteCallback alone. Any hints on why the switching callback doesn’t behave like the legacy implementation?

Why not just print the time at which the callback is called in the condition, i.e.

```julia
true_condition(u, t, integrator) = (@show t; true)
```

That will show you whether it really is being called at t = 0, as you think.

I didn’t mention it, but I tried that, and it seems to be called at the same time a step is taken. What puzzles me is why, even with discrete time steps, it acts completely differently.

Sorry, I don’t know the SDE solvers too well. What if you try Euler-Maruyama with the same fixed time step as your own implementation?

I just tried, and the result is the same. Here is an image showing how the colors “start out” wrong again (this one has non-zero noise, so the points jiggle):

[Image: Euler-Maruyama comparison]

Sorry, without a simple MWE that one can run, it is hard to diagnose what is going on. The only other thing I can think of is whether you should be returning a copy of B in the saving callback.
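To illustrate the copy question in plain Julia: SavingCallback stores whatever the save function returns, without copying it. If save_B returns integrator.p.B itself and the affect later mutates that matrix in place, every saved entry is the same object and shows the latest state, including the entry saved at t = 0. Whether this applies depends on the real affect!: the simplified one rebinds p.B, which would not work on an immutable container such as a NamedTuple, so the real code may well mutate in place. The loop below is only a stand-in for that pattern, not DiffEq code.

```julia
# Saving the same mutable object repeatedly stores references, not snapshots.
B = falses(2, 2)
saved_alias = BitMatrix[]
saved_copy  = BitMatrix[]

for step in 1:3
    push!(saved_alias, B)        # what a save function returning integrator.p.B does
    push!(saved_copy, copy(B))   # a snapshot of the current state
    B[1, 1] = !B[1, 1]           # in-place change, as an affect! mutating p.B would do
end

# Every element of saved_alias is B itself, so all of them show the final state;
# saved_copy keeps the history: false, true, false in entry [1, 1].
```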

This does not make much algorithmic sense. It’s not convergent to anything? If you’re using dt from an adaptive method to change a parameter, then the parameter simply is not well-defined. In some runs it could be zero at a point, in others it could be huge. There is no definition of the underlying SDE for the code you’ve given.

Your fixed time step code only works because it’s equivalent to a model where you have B ./ (0.01 * sum(B; dims = 2)) as your parameter. Even in an ODE sense that would not be well-defined if that parameter is changing randomly.
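A quick check of that dependence, reusing param_switch from the simplified question code (the matrix B below is made up): the result scales like 1/dt, so halving the step, as an adaptive solver may do, doubles the parameter.

```julia
# param_switch as in the simplified question code (X is unused there)
function param_switch(X, B, dt)
    avg = sum(B; dims = 2) .* dt
    return B ./ avg
end

B = [1.0 1.0; 0.0 1.0]
p_fixed = param_switch(nothing, B, 0.01)    # the legacy fixed step
p_half  = param_switch(nothing, B, 0.005)   # a smaller step an adaptive solver might pick
# p_half is twice p_fixed: the "parameter" is really a function of the step size
```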

Secondly, you didn’t share your actual driver script, so it’s impossible to tell what you’re doing here, but if I had to guess: if you run

```julia
prob = SDEProblem(drift, noise, u0, (B = BitMatrix(250, 250)))
sol = solve(prob, SRIW1(); callback = cbs, alg_hints=:additive, save_everystep = true)
sol = solve(prob, SRIW1(); callback = cbs, alg_hints=:additive, save_everystep = true)
```

you’ll notice that the second solve starts with the modified parameter at the end of the first. This is because you’re directly modifying the parameter p. If you don’t want that, you’d need to make sure you refresh that parameter:

```julia
prob = SDEProblem(drift, noise, u0, (B = BitMatrix(250, 250)))
sol = solve(prob, SRIW1(); callback = cbs, alg_hints=:additive, save_everystep = true)
prob = SDEProblem(drift, noise, u0, (B = BitMatrix(250, 250)))
sol = solve(prob, SRIW1(); callback = cbs, alg_hints=:additive, save_everystep = true)
```
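The same carry-over can be reproduced without DiffEq. In the sketch below a hypothetical run_once! stands in for one solve whose callback mutates p.B in place; the remedy is to give every run its own copy of the parameters (in SciML, rebuilding the problem as above, or something like remake(prob; p = deepcopy(p0))).

```julia
# Hypothetical stand-in for one `solve` whose callback mutates p.B in place.
function run_once!(p)
    p.B .= .!p.B            # flip every entry, as an in-place affect! would
    return nothing
end

p = (B = falses(2, 2),)     # trailing comma: a NamedTuple, not the assignment B = ...
first_start = copy(p.B)
run_once!(p)                # "first solve"
second_start = copy(p.B)    # a "second solve" would start here, not at first_start

# Remedy: keep a pristine template and hand each run its own deep copy.
p0 = (B = falses(2, 2),)
pa = deepcopy(p0); run_once!(pa)
pb = deepcopy(p0)           # unaffected by the run on pa
```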

(note: since your SDE actually has additive noise, don’t choose SRIW1; choose SRA1, SRA3, or SOSRA)

Note also that BitMatrix(250, 250) is not legal Julia code:

```
julia> BitMatrix(250, 250)
ERROR: MethodError: no method matching BitMatrix(::Int64, ::Int64)

Closest candidates are:
  BitArray{N}(::UndefInitializer, ::Int64...) where N
   @ Base bitarray.jl:28
  BitArray{N}(::UndefInitializer, ::Integer...) where N
   @ Base bitarray.jl:70
  BitArray{N}(::Any) where N
   @ Base bitarray.jl:578

Stacktrace:
 [1] top-level scope
   @ REPL[30]:1
```

So again the code you show here is clearly not what you’re actually running and I need to make a guess at the difference. I would venture to guess the difference is that you’re actually using BitMatrix(undef, 250, 250), but note that has undefined memory and is thus not equivalent to zeros(...) in your simulate! code.

```
julia> BitMatrix(undef, 250, 250)
250×250 BitMatrix:
 0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  1  1  0  …  1  0  1  0  0  0  0  0  1  1  1  0  1  0  0  0  0  0
 0  0  0  0  0  0  0  1  0  0  1  0  0  0  0  0  0  0  1     0  1  0  0  0  0  0  0  0  0  0  1  1  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  1  1  0  0  0  0  0  0  0  0     0  0  1  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0
 1  0  0  0  0  0  0  0  1  1  0  0  0  0  0  0  0  0  1     1  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0
 ⋮              ⋮              ⋮              ⋮           ⋱           ⋮              ⋮              ⋮           
 0  0  0  0  0  0  1  1  1  0  0  0  0  0  0  0  1  0  0     1  0  0  0  0  0  0  0  1  0  1  0  0  0  0  0  0  0
 0  0  0  0  1  1  1  1  0  0  0  0  0  0  0  0  0  1  1     0  0  0  0  0  0  0  0  0  0  1  1  0  0  0  0  0  0
 0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  1  0     1  0  0  0  0  0  0  0  0  1  1  0  0  0  0  0  0  0
```

It’s completely random because it’s just the unchanged contents of a slab of bits allocated to you. If you actually want zeros, you’d need to do false .* BitMatrix(undef, 250, 250).
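For completeness, a few ways to get a BitMatrix that really is all zeros. falses is the idiomatic constructor, and fill! is the in-place fix when the matrix was already allocated with undef:

```julia
# Three ways to get an all-zero 250×250 BitMatrix
B1 = falses(250, 250)                           # idiomatic constructor
B2 = fill!(BitMatrix(undef, 250, 250), false)   # allocate, then overwrite the garbage bits
B3 = false .* BitMatrix(undef, 250, 250)        # the broadcast mentioned above
```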

So it would be helpful if we could have a runnable piece of code here that compares the two, but my guess is that it’s the last problem, i.e. that you’re using an undef initialization of a bitmatrix assuming that it’s equivalent to all zeros, but it’s almost certainly not going to be all zeros.

I might have simplified too much, to the point where the meaning was lost. The SDE (in fact, a system of SDEs) comes from a paper on an ABM (agent-based model) with different kinds of agents.

There are 3 SDEs for the different types of agents:

  1. $dx_i(t) = F_i(x, y, z, t)\,dt + \sigma\,dW_i(t)$, where $x_i$ is a 2-vector and

     $$F_i(x,y,z,t) = \frac{a}{\sum_{j'} w_{ij'}(t)} \sum_{j=1}^N w_{ij}(t)\,\big(x_j(t) - x_i(t)\big) + \frac{b}{\sum_{m'} B_{im'}(t)} \sum_{m=1}^M B_{im}(t)\,\big(y_m(t) - x_i(t)\big) + \frac{c}{\sum_{l'} C_{il'}(t)} \sum_{l=1}^L C_{il}(t)\,\big(z_l(t) - x_i(t)\big)$$

  2. $\Gamma\,dy_m(t) = \big(\widetilde{x}_m(t) - y_m(t)\big)\,dt + \widetilde{\sigma}\,d\widetilde{W}_m(t)$, where $\widetilde{x}_m$ is the average position of the agents connected to $y_m$.
  3. $\gamma\,dz_\ell(t) = \big(\widehat{x}_\ell(t) - z_\ell(t)\big)\,dt + \widehat{\sigma}\,d\widetilde{W}_\ell(t)$, where $\widehat{x}_\ell$ is the average position of the agents connected to $z_\ell$.

In the model, the agents switch connections between $y_m$ and $y_n$ with individual rates that change over time. These are the changes I am trying to model with the callback. In the original code they are generated by simulating a Poisson point process with the switching rates, which is why dt is used. That Poisson process is implemented as a function that takes in parameters and returns a new BitMatrix.

The adjacency matrix is initialized by splitting the agents by quadrant, and assigning them to the y_m that lives in that quadrant. I omitted that part for simplicity. The actual code looks like:

```julia
C = _orthantize(X) |> BitMatrix
# _orthantize returns a matrix with one row per agent in which exactly one column
# is 1: if the 1st column is 1, the agent is in quadrant 1, and so on.
```
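For readers without the original helper, here is a hypothetical stand-in for _orthantize. It assumes X is an n×2 matrix of positions, one output column per quadrant, quadrants numbered counter-clockwise from the upper right, and points on an axis assigned to the non-negative side:

```julia
# Hypothetical stand-in for _orthantize: one-hot quadrant membership per agent.
function orthantize_sketch(X::AbstractMatrix{<:Real})
    n = size(X, 1)
    C = falses(n, 4)
    for i in 1:n
        x, y = X[i, 1], X[i, 2]
        # Quadrants 1..4 counter-clockwise, starting at the upper right
        q = x >= 0 ? (y >= 0 ? 1 : 4) : (y >= 0 ? 2 : 3)
        C[i, q] = true
    end
    return C
end

X = [1.0 2.0; -1.0 2.0; -1.0 -2.0; 1.0 -2.0]
C = orthantize_sketch(X)   # one agent per quadrant, so C is the 4×4 identity pattern
```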

The adjacency matrices are exactly the same on the first step for both simulations, which are started with a fresh copy of the model. This is how the problem is initialized and solved:

```julia
function build_sdeproblem(omp::OpinionModelProblem{T,D}, tspan::Tuple{T,T}) where {T,D}
    mp = omp.p
    X, Y, Z, A, B, C = omp # A, B and C are BitMatrix, X, Y, Z Array{Float64}
    L, M, n, η, a, b, c, σ, σ̂, σ̃, γ, Γ = omp.p
    # Stack all important parameters to be fed to the integrator
    P = (L=L, M=M, n=n, η=η, a=a, b=b, c=c, σ=σ, σ̂=σ̂, σ̃=σ̃, γ=γ, Γ=Γ, A=A, B=B, C=C,
         p=mp)

    u₀ = vcat(X, Y, Z)

    return SDEProblem(drift, noise, u₀, tspan, P)
end

# Solve
function simulate!(omp::OpinionModelProblem{T,D}, tspan::Tuple{T,T}; dt::T=0.01,
                   seed=MersenneTwister())::ModelSimulation where {T,D}
    # Seeding the RNG
    Random.seed!(seed)

    # Defining the callbacks
    C_cache = SavedValues(Float64, BitMatrix)
    saving_callback = SavingCallback(save_C, C_cache; saveat=dt, save_end=false,
                                     save_start=true)
    cbs = CallbackSet(influencer_switching_callback, saving_callback)

    diffeq_prob = build_sdeproblem(omp, tspan)
    diffeq_sol = solve(diffeq_prob, EM(), dt=dt; callback=cbs, alg_hints=:additive,
                       save_everystep=true)

    # TODO: Maybe use the retcode from diffeq to warn here.

    return ModelSimulation{T,D,DiffEqSolver}(diffeq_sol, omp.domain, C_cache,
                                             diffeq_prob.p.p)
end
```

Yes, but in the equations you wrote down it’s dt^2, which doesn’t match the Poisson point process form, which is dt. Remember there’s an implicit dt in the integration process, so if you multiply by dt you get a dt^2 term. It’s a bit hard to see the whole model here, but I’m wary that you’re accidentally applying it twice.
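As a reference point for the dt vs. dt^2 argument: write $r_{jk}$ for agent $j$'s switching rate to target $k$ and $\lambda_j = \sum_k r_{jk}$ for its total rate. The probability of at least one switch in an interval of length $\Delta t$, and what happens if $\Delta t$ is applied a second time, are:

$$P_{\text{once}} = 1 - e^{-\lambda_j \Delta t} = \lambda_j \Delta t + O(\Delta t^2)$$

$$P_{\text{twice}} = 1 - e^{-(\lambda_j \Delta t)\,\Delta t} = \lambda_j \Delta t^2 + O(\Delta t^4)$$

Over a horizon $T$ split into $T/\Delta t$ intervals, the expected number of switches is roughly $\lambda_j T$ in the first case but $\lambda_j T \Delta t$ in the second, which vanishes as the step is refined. A dt applied twice would therefore make the switching depend on the solver's step size.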

Would you suggest simulating the Poisson process differently on the DiffEq side? That would eliminate the dependence on dt

I’m struggling to see where this product would be sneaking in. The switching function call is the same in the hand-coded Euler-Maruyama and in the callback. They both call this function:

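```julia
function switch_influencer(C::Bm, X::T, Z::T, rates::U,
                           dt) where {Bm<:BitMatrix,T,U<:AbstractVecOrMat}
    L, n = size(Z, 1), size(X, 1)

    # rates = influencer_switch_rates(X, Z, B, C, η)
    RC = copy(C)

    for j in 1:n
        r = rand()
        lambda = sum(rates[j, :])
        # Probability that agent j switches at least once within dt
        if r < 1 - exp(-lambda * dt)
            # Choose the new influencer k with probability rates[j, k] / lambda
            p = rates[j, :] / lambda
            r2 = rand()
            k = 1
            while k < L && sum(p[1:k]) < r2   # k < L guards against round-off
                k += 1
            end

            RC[j, :] .= false
            RC[j, k] = true
        end
    end

    return RC
end
```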
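One hedged observation about comparing the two implementations under identical seeds: switch_influencer draws from the global RNG via rand(). If the solver or its noise process also draws from the global RNG (for example, to seed its own generator), the switching decisions will see a different random stream than in the legacy loop, even with zero noise. The sketch below is a variant that takes an explicit RNG. It uses the hypothetical name switch_influencer_rng, drops the unused X and Z arguments, and picks the target with a cumulative sum, which gives the same k as the while loop up to round-off. It consumes random numbers in the same order as the original and keeps the switching draws independent of anything else:

```julia
using Random

# Hypothetical variant of switch_influencer with an explicit RNG.
function switch_influencer_rng(rng::AbstractRNG, C::BitMatrix,
                               rates::AbstractMatrix, dt)
    n, L = size(C)
    RC = copy(C)
    for j in 1:n
        r = rand(rng)
        λ = sum(@view rates[j, :])
        # At least one switch of agent j within dt (never true when λ == 0)
        if r < 1 - exp(-λ * dt)
            # New target k with probability rates[j, k] / λ
            cdf = cumsum(@view rates[j, :]) ./ λ
            k = min(searchsortedfirst(cdf, rand(rng)), L)
            RC[j, :] .= false
            RC[j, k] = true
        end
    end
    return RC
end

C = BitMatrix([1 0; 1 0; 0 1])
rates = [0.0 5.0; 0.0 5.0; 5.0 0.0]
a = switch_influencer_rng(MersenneTwister(42), C, rates, 0.01)
b = switch_influencer_rng(MersenneTwister(42), C, rates, 0.01)   # same seed, same switches
```

Seeding the RNG passed in, rather than the global one, is what makes the two simulations comparable draw by draw.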

Okay I see what you’re trying to do now. There’s generally two ways to resolve a Poisson process tied to a differential equation:

  1. Use a continuous callback based on the rate to resolve every single reaction. This is exact but it’s probably too costly for your case.
  2. Resolve the discrete changes using a tau-leaping like approach.
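For contrast with option 2: for a time-varying rate λ(t), the exact route (option 1) draws a threshold E ~ Exp(1) and fires the event when the integrated rate ∫λ(s) ds reaches E. In DifferentialEquations.jl this is typically done by carrying the integral as an extra state and triggering on it with a ContinuousCallback, which is also what JumpProcesses.jl's VariableRateJump can handle for you. The plain-Julia sketch below uses a hypothetical first_event_time, with a crude fixed-step quadrature standing in for the solver, and only shows the mechanism:

```julia
using Random

# Time of the first event of an inhomogeneous Poisson process with rate λ(t),
# found by integrating λ until the integral reaches an Exp(1) threshold.
function first_event_time(rng::AbstractRNG, λ, t0, tmax; h = 1e-4)
    E = randexp(rng)      # Exp(1) threshold
    Λ = 0.0               # integrated intensity ∫ λ(s) ds since t0
    t = t0
    while t < tmax
        Λ += λ(t) * h     # left Riemann sum; a solver would integrate this as a state
        t += h
        Λ >= E && return t
    end
    return Inf            # no event before tmax
end

rng = MersenneTwister(7)
# With a constant rate 2.0 the event times are Exp(2), so their mean is about 0.5
ts = [first_event_time(rng, t -> 2.0, 0.0, 50.0; h = 1e-3) for _ in 1:2000]
```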

It seems you’re trying to do 2 and mix it with a high order adaptive SDE solver, i.e. you’re letting the SDE solver choose a dt, take a step, and then have a true callback so that after that step you resolve the Poisson process by sampling based on that dt and tau-leap the discrete part. Is that a correct description of what you’re doing?

First question, then: if you used a periodic callback whose period matches the time step of your fixed-step code, does it “work”? I’m still not quite clear on what you’re calling the working vs. non-working behavior here, but I’m starting to parse your case.
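In DifferentialEquations.jl the usual tool for this is PeriodicCallback(affect!, 0.01) from DiffEqCallbacks, which makes the integrator stop at multiples of the period and calls affect! there. The plain-Julia sketch below uses a hypothetical split_simulate and only mimics the structure. The continuous part is advanced with whatever inner step is chosen. The discrete switch runs once per period Δ and always receives Δ as its tau-leap interval, so the switching statistics no longer depend on the integrator's own dt:

```julia
# Hypothetical driver mimicking a periodic callback around an inner integrator.
function split_simulate(u0, drift, switch!, state, T, Δ, h)
    u = copy(u0)
    nperiods = round(Int, T / Δ)
    nsub = max(1, round(Int, Δ / h))   # inner steps per period
    hh = Δ / nsub                      # inner step that lands exactly on the grid
    calls = 0
    for _ in 1:nperiods
        for _ in 1:nsub
            u .+= hh .* drift(u, state)   # explicit Euler for the continuous part
        end
        switch!(state, u, Δ)              # discrete update with the fixed period
        calls += 1
    end
    return u, calls
end

# Usage: record the interval each switch call sees, for two inner step sizes
seen = Float64[]
u1, calls1 = split_simulate([1.0], (u, s) -> -u, (s, u, Δ) -> push!(s, Δ), seen, 1.0, 0.01, 0.01)
u2, calls2 = split_simulate([1.0], (u, s) -> -u, (s, u, Δ) -> push!(s, Δ), seen, 1.0, 0.01, 0.001)
```

Both runs make 100 switch calls with the same interval; only the accuracy of the continuous part changes with the inner step.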

Yes! That’s precisely it; you put it in better words than I did.

When I force the CallbackSet to act every dt = 0.01, I get the first figure I posted, the side-by-side comparison of the two processes.

The end goal is to be able to reproduce the legacy simulation under zero noise and with the same parameters. For the DiffEq solution to count as working, the switches (shown as points changing color) would match the legacy simulation perfectly. At the moment there are several differences between the two:

  1. On the very first frame of the animation (i.e. the first point of the solution) the colors are already different, even though the starting configurations are the same.
  2. Points of a color should all move toward their corner (e.g. the red corner is the upper right corner), but in the DiffEq version a red dot sometimes moves to the blue corner, and so on.

I am fairly certain that the point of failure is in the callback: without the effects of color switching, both solutions behave the same.