Improving the speed of large population simulations in DifferentialEquations.jl

Hi,

Apologies if this isn’t the correct forum for this style of question, but I think it could be viewed as a performance problem.

Sensitive/Resistant Birth Death Simulation

I have a birth-death simulation I am running in Julia. Specifically, I model sensitive and resistant cells growing over time.
I currently model the system as a Jump Problem using the ‘MassAction’ Jumps. For now, to keep the simulation simple, the populations evolve independently. Here is the code I use to set up these independent birth-death processes for sensitive (‘S’) and resistant populations (‘R’):

    # Define the rate identities
    rateidxs = [1, 2, 3, 4]
    # Define the reactant stoichiometry
    reactant_stoich = 
    [
        [1 => 1], # b for S cells
        [1 => 1], # d for S cells
        [2 => 1], # b for R cells
        [2 => 1], # d for R cells      
    ]
    # Define the net stoichiometry
    net_stoich =
    [
        [1 => 1],          # b for S cells
        [1 => -1],         # d for S cells
        [2 => 1],          # b for R cells
        [2 => -1]          # d for R cells
    ]

    # Formulate as a mass action jump problem. 
    mass_act_jump = MassActionJump(reactant_stoich, net_stoich; 
    param_idxs=rateidxs)

    prob = DiscreteProblem(u₀, tspan, p)

    jump_prob = JumpProblem(prob, Direct(), mass_act_jump,
                            save_positions= (false, false))

Additionally (not shown here) I use callbacks to do the following:

  • Switch between treated and non-treated environments periodically (I change the death rates of the sensitive cells)
  • Check for extinction (if nS+nR == 0)
  • Check for some carrying capacity being reached (if nS+nR >= Nmax)
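
The callbacks themselves are not shown in the post. As a rough sketch only, the extinction and carrying-capacity checks could be combined into one terminating callback like this. The value of `Nmax` and the names `stop_condition` and `cb_stop` are assumptions made for illustration, and state indices 1 and 2 are taken to be S and R as in the stoichiometry above.

```julia
using DifferentialEquations

# Hypothetical carrying capacity; the post does not give its value.
Nmax = 10_000

# True when both populations are extinct or the total reaches Nmax.
stop_condition(u, t, integrator) = (u[1] + u[2] == 0) || (u[1] + u[2] >= Nmax)

# Stop the integration as soon as the condition holds.
cb_stop = DiscreteCallback(stop_condition, terminate!;
                           save_positions = (false, false))
```

Such a callback would be passed to `solve` via the `callback` keyword, possibly inside a `CallbackSet` together with the treatment-switching callbacks.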

Parameter Inference with ABC

This all runs as expected, and for a starting population of 1000 cells, and given the other parameters (time of growth, max population size, ratio of sensitive:resistant cells etc.) this takes around 1s to run. I use the simulation to infer various parameters via ABC and I can recover the true values of parameters I care about.

The problem is that the runtime of this simulation scales linearly with the number of cells. I eventually aim to use this for systems with millions of cells, and because of the large number of simulations one has to run for ABC, the time it takes to run becomes prohibitive.

Speeding up the Simulation

Are there any creative ways I might significantly improve the speed of this simulation? By significantly, I mean breaking the current linear relationship between cell number and simulation run time.

Thoughts I have had so far:

  • It feels wasteful to simulate each individual cell (which I think is what is going on under the hood in these JumpProblems, and which would explain how the simulation scales with the number of cells) when I only care about 2 types. Can I leverage the fact that I only have two species to speed up the simulation somehow?
  • Instead of a jump problem I could use a stochastic differential equation (SDE). The problem here is that, because the population sizes often start at or approach 0, the SDE solution becomes unstable.
  • Because the stochastic component of the model becomes insignificant once the population sizes reach a threshold, could I do some adaptive switching from a JumpProcess to an ODE when some population size is reached?
  • Related to the previous point, is there a clever way I could combine a JumpProcess and an ODE so that the model constantly switches depending on the population size? And would there be a way to still have the two models ‘talk to each other’ (say, if I wanted to allow cells to transition from one type to another, but one sub-population was currently being modelled via a JumpProcess whilst the other was modelled via an ODE)?

Thanks for any help, and apologies again if this is too vague a question for the Performance topic - please point me in the right direction if so.

ABC is slow because it’s a derivative-free approach. This is a perfect test case for trying something derivative-based with StochasticAD.jl (gaurav-arya/StochasticAD.jl on GitHub, a research package for automatic differentiation of programs containing discrete randomness). Let’s chat with Gaurav: we could probably use this as an example for an upcoming paper. This would likely be the biggest improvement, as changing to a derivative-based approach like HMC would probably reduce the required number of samples by a ton.

If I’m reading this correctly, there’s only like 4 states? That’s probably not enough to specialize on.

That’s expected: the CLE approximation is only for n-> infinity.

Yes indeed! Though once you make one of the values continuous, you do have to be careful that the rates are at least approximately constant between jumps, or change to VariableRateJumps (for which there are now much faster tools). But yes, a good way to do this is a hybrid simulation where things change to jumps when n is sufficiently small.

Use a callback. I don’t think I have an example written for doing this, but if you check the JumpProcesses.jl examples you’ll see a lot of examples of mixing differential equations with jump processes.

Yes, you’d just have to make that part of the rate equation.

A birth reaction like S --> 2S at rate k has probability per time to occur of k S(t). So the larger the S population is, the smaller the average timestep in a jump process simulation, which makes the simulation slower and slower as the population grows.
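
To make that scaling explicit, here is a short worked version for a single population with per-cell birth rate b and death rate d. The symbols S_0 (initial population) and T (end time) are notation introduced here for illustration; the result uses the standard mean E[S(t)] = S_0 e^{(b-d)t} of a linear birth-death process.

```latex
% Total event rate (propensity) and mean waiting time at population size S
a(S) = (b + d)\,S, \qquad \mathbb{E}[\Delta t \mid S] = \frac{1}{(b + d)\,S}

% Expected number of simulated events up to time T (for b \neq d)
\mathbb{E}[N_{\text{events}}(T)]
  = (b + d) \int_0^T \mathbb{E}[S(t)]\,dt
  = (b + d)\, S_0\, \frac{e^{(b - d)T} - 1}{b - d}
```

The expected number of events, and hence the work done by an exact simulation, is proportional to S_0, which matches the linear scaling with cell number reported in the original post.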

The best way to get around this would likely be to have a model that switches between the exact jump process simulation and a tau-leaping approximation depending on the population size, but I don’t think that can be done currently. @ChrisRackauckas do the tau-leaping solvers have the necessary event support for this?

Yes, the stochastic diffeq ones should work the same as the SDE solvers.

Oh cool, then I guess we should be able to handle this via callbacks right now.

Exactly, you pay a double price: the computations are heavier and the time steps are shorter. It is straightforward to derive the SDE approximation from the CLT. Would that be enough?
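
For reference, a sketch of what that SDE (chemical Langevin) approximation looks like for one population with per-cell birth rate b and death rate d. The notation, with W_t a standard Brownian motion, is an assumption made here and does not come from the thread.

```latex
% Chemical Langevin approximation of the linear birth-death process
dS_t = (b - d)\,S_t\,dt + \sqrt{(b + d)\,S_t}\;dW_t

% Size of the noise relative to the drift
\frac{\sqrt{(b + d)\,S_t}}{|b - d|\,S_t} \;\propto\; \frac{1}{\sqrt{S_t}}
```

The square root is only defined for S_t >= 0, so a discretised step that pushes S_t slightly below zero breaks the scheme, which is consistent with the instability near 0 reported above. The relative noise shrinks like 1/sqrt(S_t), which is why a deterministic ODE becomes reasonable at large population sizes.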

I have the feeling that, in the end, you will have to make an approximation (see above or below).

This would be a killer! Has it been done somewhere? What about the numerical analysis?

We have the components in tau-leaping and jump solvers, but haven’t put together an explicit coupling algorithm. It may be that one can hack it via callbacks that turn on/off the individual components for a hybrid model, but I haven’t tried coupling jumps to tau-leaping before.
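
To illustrate the idea of switching between exact jumps and tau-leaping, here is a minimal hand-rolled sketch for a single birth-death population. It does not use the JumpProcesses.jl solvers or callbacks, and it is not a partitioned-leaping implementation. The function name `hybrid_birth_death`, the threshold `nswitch`, and the fixed leap size `tau` are all assumptions chosen for illustration, not tuned values.

```julia
using Random
using Distributions  # provides Poisson

# Exact SSA steps below `nswitch`, fixed-step tau-leaping at or above it.
# Returns the population size at time `tmax` (or 0 if it went extinct).
function hybrid_birth_death(n0::Int, b, d, tmax; nswitch = 1000, tau = 0.01,
                            rng = Random.default_rng())
    n, t = n0, 0.0
    while t < tmax && n > 0
        if n < nswitch
            # Exact step: exponential waiting time, then one birth or death.
            total = (b + d) * n
            t += randexp(rng) / total
            t > tmax && break          # next event falls after the end time
            n += rand(rng) < b / (b + d) ? 1 : -1
        else
            # Tau-leap: Poisson numbers of births and deaths over one step,
            # with rates frozen at the current population size.
            h = min(tau, tmax - t)
            births = rand(rng, Poisson(b * n * h))
            deaths = rand(rng, Poisson(d * n * h))
            n = max(n + births - deaths, 0)
            t += h
        end
    end
    return n
end
```

For example, `hybrid_birth_death(100, 0.893, 0.2, 18.0)` uses the parameter values that appear later in this thread. Above the threshold the cost per unit of simulated time is fixed by `tau` rather than by the population size, at the price of an approximation error that grows with `tau`.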

There are formal papers on such couplings, and probably more rigorous papers. Most of the references I know are of the algorithmic variety from J. Chem. Phys. and such. I vaguely recall that maybe Christof Schuette had a more mathematical paper on such a coupling, but it was a while ago.

One reference is A “partitioned leaping” approach for multiscale modeling of chemical reaction dynamics by Harris and Clancy in J. Chem. Phys. (2006).

Hi @ChrisRackauckas - thanks for your reply.

I’m happy to explore this further, with the caveats that I haven’t had a chance to look through the paper in detail yet and lack formal mathematical training so this would probably require some patience on your part. Would you recommend setting up the game of life example as a toy model to get some intuition for how the approach works?

Sorry can you just clarify what you mean by ‘… to specialize on’?

Finally, here is an attempt I’ve made to switch between the jump model and the ODE when some population threshold is crossed. It is almost working, but I’m sure this is not an efficient way to do it. There is also an issue where, when it switches back to the jump model, there is often a sudden jump in the population size that isn’t consistent with the model. I’m thinking it’s probably an issue with how I’m using the callbacks. I would appreciate any feedback!

using DifferentialEquations
using JumpProcesses
using Distributions
using Plots

# Jump model:
#############

function make_jprob(u0, tspan, p)

    # Define the rate identities
    rateidxs = [1, 2]

    # Define the reactant stoichiometry
    reactant_stoich = 
    [
        [1 => 1], # birth
        [1 => 1], # death 
    ]
    # Define the net stoichiometry
    net_stoich =
    [
        [1 => 1], # birth 
        [1 => -1] # death 
    ]

    # Formulate as a mass action jump problem. 
    mass_act_jump = MassActionJump(reactant_stoich, net_stoich; 
    param_idxs=rateidxs)

    prob = DiscreteProblem(u0, tspan, p)

    jump_prob = JumpProblem(prob, Direct(), mass_act_jump,
                            save_positions= (false,false))

    return jump_prob

end

# ODE model: 
############

function make_oprob(u0, tspan, p)

    function ode_fxn(du, u, p, t)
        
        n = u[1]
        b,d = p

        du[1] = dn = (b - d)*n

    end

    ode_prob = ODEProblem(ode_fxn, u0, tspan, p)

    return ode_prob

end

# callbacks
###########

# Switch from SJM -> ODE
cond_switch1(u, t, integrator) = integrator.u[1] >= Nswitch # switches when hits
cb_switch1 = DiscreteCallback(cond_switch1, terminate!,
                              save_positions = (false, false))
# (If I saved the second position, I'd sometimes see the time jump backwards
# if the population size instantly went back up to > Nswitch?)

# Switch from ODE -> SJM
cond_switch2(u,t,integrator) = (u[1] - (Nswitch-1)) # switches when 1 less
cb_switch2 = ContinuousCallback(cond_switch2, terminate!,
                               save_positions = (false, true))

# Turn treatment on 
cond_treat1(u, t, integrator) = t ∈ treat_ons
function treat_on!(integrator)
    integrator.p[2] = d_t1
    reset_aggregated_jumps!(integrator)
    nothing
end
cb_treat1 = DiscreteCallback(cond_treat1,treat_on!,
                             save_positions = (false,false))

# Turn treatment off
cond_treat0(u, t, integrator) = t ∈ treat_offs 
function treat_off!(integrator)
    integrator.p[2] = d_t0
    reset_aggregated_jumps!(integrator)
    nothing
end
cb_treat0 = DiscreteCallback(cond_treat0,treat_off!,
                             save_positions = (false,false))

# Create a callback set for the sjm and ode models. 
cbs_sjm = CallbackSet(cb_switch1, cb_treat1, cb_treat0)
cbs_ode = CallbackSet(cb_switch2, cb_treat1, cb_treat0)

# A function that decides which model to use given the current population size. 
function mod_decis(curr_u, curr_t, tmax, p, Nswitch)

    if curr_u[1] < Nswitch 
        sjm_prob = make_jprob(curr_u, (curr_t,tmax), p)
        sol = solve(sjm_prob, SSAStepper(), saveat=0.05, tstops = drug_tvec,
                        callback = cbs_sjm)
        # Callback still returns all ts but with N fixed to Nswitch?
        if sol.retcode == :Terminated
            sol_t = sol.t[1:findfirst(sum.(sol.u) .>= Nswitch)]
            sol_u = sol.u[1:findfirst(sum.(sol.u) .>= Nswitch)]
        else
            sol_t = sol.t
            sol_u = sol.u
        end            
    else
        ode_prob = make_oprob(curr_u, (curr_t,tmax), p)
        sol = solve(ode_prob, Tsit5(), saveat=0.05, tstops = drug_tvec,
                        callback = cbs_ode)
        sol_t = sol.t
        sol_u = sol.u
    end

    return sol_t, sol_u 

end

# Parameters
############
n0 = 100; tmax = 18.0; b = 0.893; d = 0.200; 
# And collect initial conditions and parameters for the problem
u0 = [n0]; tspan = (0.0,tmax); p = [b,d]; Nswitch = 1000
# Treatment parameters: Dc controls the strength of the drug-induced killing
Dc = 6.0; d_t1 = d * Dc; d_t0 = d;
# Treatment times 
dt = 4.0; drug_tvec = collect(dt:dt:tmax)
treat_ons = drug_tvec[1:2:length(drug_tvec)]
treat_offs = drug_tvec[2:2:length(drug_tvec)]

# Run
#####

# Keep track of current population size and time. 
curr_t = tspan[1]
curr_u = u0

# Store solve outputs across all switches
t_out = Vector{Float64}(undef, 0)
u_out = Vector{Vector{Float64}}(undef, 0)

while curr_t < tmax 

    sol = mod_decis(curr_u, curr_t, tmax, p, Nswitch)

    curr_t = last(sol[1])
    curr_u = map(x -> Int64(round(x[1])), last(sol[2]))

    append!(t_out, sol[1])
    append!(u_out, sol[2])

end

plot(t_out, map(x -> x[1], u_out), yaxis=:log10)

If anybody stumbles across this again, I arrived at the following version as a more elegant way of having a model that switches between a stochastic jump process (labelled SJM in the code) and a deterministic ODE when some population threshold is crossed.
Instead of writing two separate models and manually switching between them (as in my previous answer, which was cumbersome), it now includes both ODE rates and SJM rates in the parameter vector, and the callbacks turn these rates on and off when the population threshold (Nswitch) is crossed.

Here is a simple one-type birth-death process that does this:

using DifferentialEquations
using JumpProcesses

# Initial pop size, max time, birth and death rates. 
n0 = 10.0; tmax = 12.0; b = 1.0; d = 0.4; 

# Have separate rate parameters for births and deaths
# in the ODE and SJM problems. Set the ODE positions 
# (1 and 2) to 0.0 to begin. 

u0 = [n0]; tspan = (0.0,tmax); p = [0.0,0.0,b,d];

# ODE problem 
#############

function ode_fxn(du, u, p, t)
        
    b_o,d_o,b_j,d_j = p

    du[1] = (b_o - d_o)*u[1] 

end

ode_prob = ODEProblem(ode_fxn, u0, tspan, p)

# SJM Problem 
############

function birth!(integrator)
    integrator.u[1] = integrator.u[1] + 1
    nothing
end

function death!(integrator)
    integrator.u[1] = integrator.u[1] - 1
end

b_rate(u, p, t) = (u[1] * p[3])
d_rate(u, p, t) = (u[1] * p[4])

b_jump = VariableRateJump(b_rate, birth!)
d_jump = VariableRateJump(d_rate, death!)

# Can now turn the ODE problem into a SJM problem

sjm_prob = JumpProblem(ode_prob, Direct(), b_jump, d_jump)

# Callbacks
###########

# Switch from SJM -> ODE
cond_switch1(u, t, integrator) = integrator.u[1] >= Nswitch

# Affect for switching - turn the ODE integrator rates on and SJM rates off. 
function switch_1!(integrator)
    integrator.p[1] = b
    integrator.p[2] = d
    integrator.p[3] = 0.0
    integrator.p[4] = 0.0
    nothing
end

cb_switch1 = DiscreteCallback(cond_switch1, switch_1!, 
                              save_positions = (false, true))

# Switch from ODE -> SJM
cond_switch2(u, t, integrator) = (u[1] - (Nswitch-1)) # switches when 1 less

# Affect for switching - turning the SJM integrator rates on and ODE rates off. 
function switch_2!(integrator)
    # Round the population size to a whole number for the SJM.
    integrator.u[1] = round(integrator.u[1])
    integrator.p[1] = 0.0
    integrator.p[2] = 0.0
    integrator.p[3] = b
    integrator.p[4] = d
    nothing
end

cb_switch2 = ContinuousCallback(cond_switch2, switch_2!,
                                save_positions = (false, true))

# Turn into a callback set
cbs = CallbackSet(cb_switch1, cb_switch2)

# Run
#####

# Set population size to switch between model types
Nswitch = 500

sol = solve(sjm_prob, Tsit5(), callback = cbs, adaptive=false, dt = 0.5)

# Plot
#####

using Plots

plot(sol, yaxis=:log10)
# Highlight where the switch occurs
plot!([0.0, tmax], [Nswitch, Nswitch])

@freddie090 that is great to hear! Something like that is what I was thinking about in my comment above. Ideally we could automate this kind of switching, and have a scheduler to handle shuffling species / reactions between different scales. I think once we finish building up the variable rate solvers it would make sense to try to tackle this more systematically.
