Improving the speed of large population simulations in DifferentialEquations.jl

Hi @ChrisRackauckas - thanks for your reply.

I’m happy to explore this further, with the caveats that I haven’t had a chance to look through the paper in detail yet and that I lack formal mathematical training, so this would probably require some patience on your part. Would you recommend setting up the Game of Life example as a toy model to get some intuition for how the approach works?

Sorry, could you clarify what you mean by ‘… to specialize on’?

Finally, here is an attempt I’ve made to switch between the jump model and the ODE when a population threshold is crossed. It is almost working, but I’m sure this is not an efficient way to do it. There is also an issue where, when it switches back to the jump model, there is often a sudden jump in the population size that isn’t consistent with the model… I think it’s probably an issue with how I’m using the callbacks. I would appreciate any feedback!

using DifferentialEquations
using JumpProcesses
using Distributions
using Plots

# Jump model:
#############

"""
    make_jprob(u0, tspan, p)

Build the stochastic birth–death jump problem for one species.

Two first-order mass-action reactions act on species 1:
- birth: N → 2N (consumes one N, net change +1), rate constant `p[1]`
- death: N → ∅  (consumes one N, net change −1), rate constant `p[2]`

The rate constants are referenced through `param_idxs` rather than copied in as
numbers. Events are not saved by the jump problem itself (`save_positions`);
sampling is left to the solver's `saveat`.
"""
function make_jprob(u0, tspan, p)
    # One entry per reaction: what each reaction consumes, and its net effect.
    substrates = [[1 => 1],   # birth
                  [1 => 1]]   # death
    net_change = [[1 => 1],   # birth
                  [1 => -1]]  # death

    births_and_deaths = MassActionJump(substrates, net_change; param_idxs = [1, 2])

    discrete_prob = DiscreteProblem(u0, tspan, p)
    return JumpProblem(discrete_prob, Direct(), births_and_deaths;
                       save_positions = (false, false))
end

# ODE model: 
############

"""
    make_oprob(u0, tspan, p)

Build the deterministic counterpart of the birth–death model,
dN/dt = (b − d)·N with `b, d = p`, as an in-place `ODEProblem`.

The initial state is converted to floating point. Callers may pass the integer
population produced by the jump model, and an integer state vector cannot hold
the non-integer values the ODE solver writes into it.
"""
function make_oprob(u0, tspan, p)
    # In-place right-hand side; reads `p` on every call, so changes to `p[2]`
    # made by callbacks take effect immediately.
    function ode_fxn(du, u, p, t)
        b, d = p
        du[1] = (b - d) * u[1]
        return nothing
    end

    return ODEProblem(ode_fxn, float.(u0), tspan, p)
end

# callbacks
###########

# Switch from SJM -> ODE: checked after every jump; terminates the jump simulation
# as soon as the population reaches `Nswitch`. Neither the pre- nor the post-event
# state is saved here; the caller recovers the switch point from the saveat samples.
# NOTE(review): the "time jumps backwards" seen when saving the post-event state is
# most likely caused by passing SSAStepper tstops earlier than the segment start,
# not by this callback — confirm once tstops are filtered per segment.
cond_switch1(u, t, integrator) = first(integrator.u) >= Nswitch
cb_switch1 = DiscreteCallback(cond_switch1, terminate!;
                              save_positions = (false, false))

# Switch from ODE -> SJM: root-find the time at which the population falls through
# `Nswitch - 1` (one below the SJM -> ODE threshold, giving a little hysteresis).
# Only a *downward* crossing should hand control back to the jump model, so the
# upcrossing affect is `nothing` and `terminate!` is passed as `affect_neg!`.
# The post-event state is saved so the ODE segment ends exactly at the switch point.
cond_switch2(u, t, integrator) = u[1] - (Nswitch - 1)
cb_switch2 = ContinuousCallback(cond_switch2, nothing, terminate!;
                                save_positions = (false, true))

# Turn treatment on: fires at each time listed in `treat_ons` (exact float match,
# so those times must be hit exactly, e.g. by passing them as `tstops`).
cond_treat1(u, t, integrator) = in(t, treat_ons)

function treat_on!(integrator)
    # Raise the death rate to its treated value.
    integrator.p[2] = d_t1
    # Ask the jump aggregator (if the integrator has one) to refresh its rates
    # after the change to `p`.
    reset_aggregated_jumps!(integrator)
    return nothing
end

cb_treat1 = DiscreteCallback(cond_treat1, treat_on!; save_positions = (false, false))

# Turn treatment off: fires at each time listed in `treat_offs` (exact float match,
# so those times must be hit exactly, e.g. by passing them as `tstops`).
cond_treat0(u, t, integrator) = in(t, treat_offs)

function treat_off!(integrator)
    # Restore the untreated death rate.
    integrator.p[2] = d_t0
    # Ask the jump aggregator (if the integrator has one) to refresh its rates
    # after the change to `p`.
    reset_aggregated_jumps!(integrator)
    return nothing
end

cb_treat0 = DiscreteCallback(cond_treat0, treat_off!; save_positions = (false, false))

# Create a callback set for the sjm and ode models.
# Both sets share the two treatment callbacks (cb_treat1, cb_treat0); they differ only
# in how a switch to the other model is detected: cb_switch1 is a discrete check after
# each jump, cb_switch2 a continuous root-find on the ODE solution.
cbs_sjm = CallbackSet(cb_switch1, cb_treat1, cb_treat0)
cbs_ode = CallbackSet(cb_switch2, cb_treat1, cb_treat0)

"""
    mod_decis(curr_u, curr_t, tmax, p, Nswitch; saveat = 0.05)

Simulate one segment of the hybrid model on `(curr_t, tmax)` and return its
samples as `(ts, us)`.

If `curr_u[1] < Nswitch`, the stochastic jump model (`make_jprob`, `cbs_sjm`) is used
and stops once the population reaches `Nswitch`. Otherwise the ODE (`make_oprob`,
`cbs_ode`) is used and stops when the population falls to `Nswitch - 1`. The caller
continues from the last returned sample.

`p` is passed to the problem as-is; the treatment callbacks write `p[2]`.
"""
function mod_decis(curr_u, curr_t, tmax, p, Nswitch; saveat = 0.05)
    # Only treatment times strictly after the segment start belong to this segment.
    # Handing SSAStepper tstops that lie *before* `curr_t` can make it step its clock
    # back to them and re-simulate already-covered time. That shows up as a sudden
    # population jump (and time running backwards) right after switching to the SSA.
    seg_tstops = filter(>(curr_t), drug_tvec)

    # DiscreteCallbacks are not evaluated at a segment's initial time, so a treatment
    # switch falling exactly on `curr_t` would otherwise be skipped. Setting the rate
    # is idempotent, so re-applying one already handled by the previous segment is safe.
    if curr_t in treat_ons
        p[2] = d_t1
    elseif curr_t in treat_offs
        p[2] = d_t0
    end

    if curr_u[1] < Nswitch
        sjm_prob = make_jprob(curr_u, (curr_t, tmax), p)
        sol = solve(sjm_prob, SSAStepper(); saveat = saveat, tstops = seg_tstops,
                    callback = cbs_sjm)
        # After `terminate!`, the remaining saveat times are still filled with the
        # final (>= Nswitch) state. Keep samples up to the first such one; the switch
        # time is therefore rounded up to the saveat grid.
        hit = findfirst(u -> u[1] >= Nswitch, sol.u)
        stop = hit === nothing ? lastindex(sol.t) : hit
        return sol.t[1:stop], sol.u[1:stop]
    else
        # The ODE needs a floating-point state even though the SSA state is integer.
        ode_prob = make_oprob(float.(curr_u), (curr_t, tmax), p)
        sol = solve(ode_prob, Tsit5(); saveat = saveat, tstops = seg_tstops,
                    callback = cbs_ode)
        return sol.t, sol.u
    end
end

# Parameters
############
# Population model: initial size, end time, per-capita birth and death rates.
n0 = 100
tmax = 18.0
b = 0.893
d = 0.200

# Initial condition, time span and parameter vector p = [b, d]
# (the treatment callbacks overwrite p[2] in place).
u0 = [n0]
tspan = (0.0, tmax)
p = [b, d]
Nswitch = 1000  # population size at which the jump model hands over to the ODE

# Treatment: Dc scales the death rate while the drug is on (drug-induced killing).
Dc = 6.0
d_t1 = d * Dc  # death rate with treatment
d_t0 = d       # death rate without treatment

# Treatment switch times every dt; the drug alternates on, off, on, off, ...
dt = 4.0
drug_tvec = collect(dt:dt:tmax)
treat_ons = drug_tvec[1:2:end]
treat_offs = drug_tvec[2:2:end]

# Run
#####

"""
    run_hybrid(u0, tspan, p, Nswitch)

Chain `mod_decis` segments from `tspan[1]` until `tspan[2]` is reached. Each new
segment starts from the last sample of the previous one. Returns `(t_out, u_out)`.

The loop lives in a function rather than at top level. When a file is run as a
script, Julia's soft-scope rule makes assignments to existing globals inside a
top-level `while` body create new locals, so the loop would never terminate.
"""
function run_hybrid(u0, tspan, p, Nswitch)
    tmax = tspan[2]
    curr_t = tspan[1]
    curr_u = u0

    # Samples from all segments, concatenated in time order.
    t_out = Float64[]
    u_out = Vector{Float64}[]

    while curr_t < tmax
        seg_t, seg_u = mod_decis(curr_u, curr_t, tmax, p, Nswitch)
        # Guard against a segment that does not advance time (would loop forever).
        last(seg_t) > curr_t ||
            error("hybrid simulation made no progress at t = $curr_t")

        # A segment's first sample repeats the previous segment's last time; skip it
        # so that switch times are not duplicated in the output.
        first_new = !isempty(t_out) && first(seg_t) == last(t_out) ? 2 : 1
        append!(t_out, seg_t[first_new:end])
        append!(u_out, seg_u[first_new:end])

        curr_t = last(seg_t)
        # The jump model works with integer counts, so round the final state.
        curr_u = round.(Int, last(seg_u))
    end

    return t_out, u_out
end

t_out, u_out = run_hybrid(u0, tspan, p, Nswitch)

plot(t_out, first.(u_out), yaxis = :log10)