Memory issues for long SSA simulations with mass action jumps

I’m trying to find the stationary distribution of a master equation whose state space is bounded but can be too large to solve with FiniteStateProjection.jl, so I’m left with running Gillespie simulations.

Running long simulations to reach the steady state uses a lot of memory and eventually degrades performance. The full trajectory isn’t needed to compute the steady state, only the time spent in each state, so I’d like to track just that and normalize it into the stationary distribution at the end.
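Written out, the quantity to track is the time-averaged occupancy of a single trajectory observed on $[0, T]$. Take jump times $t_1 < \dots < t_K$, let $X_k$ be the state entered at $t_k$ ($X_0$ is the initial state), and set $t_0 = 0$, $t_{K+1} = T$. Each holding time $t_{k+1} - t_k$ belongs to the state $X_k$ occupied before the next jump, and for an irreducible chain on a finite state space $\hat{\pi}(x)$ converges to the stationary distribution $\pi(x)$ as $T \to \infty$:

$$
\hat{\pi}(x) = \frac{1}{T}\int_0^T \mathbf{1}\{X(t) = x\}\,dt = \frac{1}{T}\sum_{k=0}^{K} (t_{k+1} - t_k)\,\mathbf{1}\{X_k = x\}
$$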

However, this doesn’t seem to be possible with JumpProcesses.jl: whenever there are variable jump rates, every state is saved, even though only the previous state is needed for the next time step.

Below is a simple example of how a callback can track the cumulative time spent in each state, but the solver still records the full trajectory.

Is there something I’m missing, or could JumpProcesses.jl be extended to store only the previous state?

using Catalyst
using DifferentialEquations

two_state_model = @reaction_network begin
    (k1, k2), X1 <--> X2
end

p = (:k1 => 1.0, :k2 => 2.0)
N = 50
u₀ = [:X1 => N, :X2 => 0]
tspan = (0.0, 1e3)

# track cumulative time spent in each possible state (x1, x2)
# when normalized this gives the ss distribution
accumulated_time = zeros((N + 1, N + 1))
last_t = 0.0
function aggregate_state!(integrator)
    x1, x2 = integrator.u
    dt = integrator.t - last_t
    x1_max, x2_max = size(accumulated_time)
    if x1 < x1_max && x2 < x2_max
        accumulated_time[x1+1, x2+1] += dt
    end
    global last_t = integrator.t
end
# update aggregated values each time step
cb = DiscreteCallback((u, t, integrator) -> true, aggregate_state!)

prob = DiscreteProblem(two_state_model, u₀, tspan, p)
# setting save_positions does nothing because the problem has VariableRateJumps 
jump_prob = JumpProblem(two_state_model, prob, Direct(), save_positions=(false, false))
# not able to set save_everystep=false as it's set manually by JumpProblem
sol = solve(jump_prob, SSAStepper(); callback=cb, save_end=true)
ss = accumulated_time ./ sol.t[end]

By default, callbacks and VariableRateJumps save the state both before and after their affect is called. Since your callback’s condition is always true, the state gets saved at every jump, right before and right after your affect runs. You can turn this off by passing save_positions = (false, false) to the callback. Here is an updated version of your code that uses the current Catalyst API and shows how to turn off this saving.

using Catalyst
using JumpProcesses
using SciMLBase

two_state_model = @reaction_network begin
    (k1, k2), X1 <--> X2
end

p = (:k1 => 1.0, :k2 => 2.0)
N = 50
u₀ = [:X1 => N, :X2 => 0]
tspan = (0.0, 1e3)

# track cumulative time spent in each possible state (x1, x2)
# when normalized this gives the ss distribution
accumulated_time = zeros((N + 1, N + 1))
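last_t = 0.0
last_u = [N, 0]  # state occupied since last_t, starting from u₀ (order X1, X2)
function aggregate_state!(integrator)
    # integrator.u is already the post-jump state, so the interval
    # [last_t, t] is credited to the state held *before* this jump
    x1, x2 = last_u
    dt = integrator.t - last_t
    x1_max, x2_max = size(accumulated_time)
    if x1 < x1_max && x2 < x2_max
        accumulated_time[x1+1, x2+1] += dt
    end
    last_u .= integrator.u
    global last_t = integrator.t
end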
# update aggregated values each time step
cb = DiscreteCallback((u, t, integrator) -> true, aggregate_state!; save_positions=(false, false))

jinputs = JumpInputs(two_state_model, u₀, tspan, p)
# no VariableRateJumps here (every reaction is a mass action jump), so this turns off saving at each jump
jump_prob = JumpProblem(jinputs; save_positions=(false, false))
# save_everystep is not passed to solve; SSAStepper decides whether to save every jump from save_positions above
sol = solve(jump_prob; callback=cb)
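# credit the time between the last jump and the end of tspan to the final state
x1_end, x2_end = sol.u[end]
accumulated_time[x1_end+1, x2_end+1] += sol.t[end] - last_t
ss = accumulated_time ./ sol.t[end]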
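
For this particular toy model the exact answer is known, which gives a quick sanity check on `ss`. Each molecule independently converts X1 → X2 at rate k1 and back at rate k2, so at stationarity X1 ~ Binomial(N, k2/(k1 + k2)). A small pure-Julia sketch of that check; the helper names `exact_x1_pmf`, `x1_marginal` and `tv_distance` are hypothetical, not from any package:

```julia
# Hypothetical helpers for checking the estimate on the X1 <--> X2 toy model.

# Exact stationary pmf of X1 (entry n+1 holds P(X1 = n)): each of the N
# molecules is independently X1 with probability k2 / (k1 + k2).
# Note: Int `binomial` limits this sketch to N <= 66 before it overflows.
function exact_x1_pmf(N::Integer, k1::Real, k2::Real)
    q = k2 / (k1 + k2)
    return [binomial(N, n) * q^n * (1 - q)^(N - n) for n in 0:N]
end

# Marginal distribution of X1 from the occupancy matrix ss[x1+1, x2+1].
x1_marginal(ss::AbstractMatrix) = vec(sum(ss; dims = 2))

# Total variation distance between two pmfs on the same support.
tv_distance(p, q) = sum(abs.(p .- q)) / 2

pmf = exact_x1_pmf(50, 1.0, 2.0)
println(sum(pmf))             # ≈ 1.0
println(sum((0:50) .* pmf))   # ≈ 33.33, i.e. N * k2 / (k1 + k2)
```

After the run above, `tv_distance(x1_marginal(ss), exact_x1_pmf(N, 1.0, 2.0))` should shrink as the end of `tspan` is pushed out.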

I’d also suggest reading through Event Handling and Callback Functions · DifferentialEquations.jl to understand how callbacks work. In particular, you might want to use a function-object (a callable struct) to hold the values you are modifying instead of global variables. The AutoAbsTol example shows how to do this: Event Handling and Callback Functions · DifferentialEquations.jl.
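Along those lines, here is a minimal sketch of a function-object version, assuming the two-species model above with states ordered (X1, X2). The mutable struct `OccupancyTimer` and the helper `stationary_estimate!` are hypothetical names, not part of any package. Besides removing the globals, it credits each interval to the state occupied before the jump, since by the time a `DiscreteCallback` runs under `SSAStepper`, `integrator.u` already holds the post-jump state. Because it only reads `integrator.u` and `integrator.t`, the demo at the end drives it with NamedTuples instead of a real integrator.

```julia
# Hypothetical function-object (functor) replacing the globals
# `accumulated_time` and `last_t`; usable as the affect! of a DiscreteCallback.
mutable struct OccupancyTimer
    accumulated_time::Matrix{Float64}  # time in state (x1, x2) stored at [x1+1, x2+1]
    last_u::Vector{Int}                # state occupied since `last_t`
    last_t::Float64                    # time of the previous jump (or t0)
end

OccupancyTimer(N::Integer, u0, t0) =
    OccupancyTimer(zeros(N + 1, N + 1), collect(Int, u0), Float64(t0))

# Called after each jump: credit [last_t, t] to the pre-jump state,
# then remember the post-jump state held in integrator.u.
function (ot::OccupancyTimer)(integrator)
    x1, x2 = ot.last_u
    if checkbounds(Bool, ot.accumulated_time, x1 + 1, x2 + 1)
        ot.accumulated_time[x1 + 1, x2 + 1] += integrator.t - ot.last_t
    end
    ot.last_u .= integrator.u
    ot.last_t = integrator.t
    return nothing
end

# Credit the interval from the last jump to the final time tf, then normalize.
function stationary_estimate!(ot::OccupancyTimer, tf)
    x1, x2 = ot.last_u
    ot.accumulated_time[x1 + 1, x2 + 1] += tf - ot.last_t
    ot.last_t = tf
    return ot.accumulated_time ./ sum(ot.accumulated_time)
end

# Stand-alone demo with NamedTuples in place of the integrator:
# start in (2, 0), jump to (1, 1) at t = 0.5, back to (2, 0) at t = 2.0, stop at t = 3.0.
ot = OccupancyTimer(2, [2, 0], 0.0)
ot((u = [1, 1], t = 0.5))
ot((u = [2, 0], t = 2.0))
est = stationary_estimate!(ot, 3.0)
println(est[3, 1], " ", est[2, 2])   # 0.5 0.5
```

With the model above, it would be hooked up roughly as `ot = OccupancyTimer(N, [N, 0], tspan[1])`, `cb = DiscreteCallback((u, t, integrator) -> true, ot; save_positions=(false, false))`, `sol = solve(jump_prob; callback=cb)` and finally `ss = stationary_estimate!(ot, sol.t[end])`. A fresh `OccupancyTimer` is needed for each new solve, since it carries state between calls.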