I’m trying to find the stationary distribution of a master equation whose state space is bounded but can be too large to solve with FiniteStateProjection.jl, so I’m left with running Gillespie simulations.
Running long simulations to reach the steady state uses a lot of memory, which eventually degrades performance. The full trajectory isn’t actually needed to compute the steady state, only the total time spent in each state. I’d like to track just that and normalize it at the end to get the stationary distribution.
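Written out, the estimator only needs the holding time τ_i of each visited state X_i. Here the i-th interval is the time between consecutive jumps, and the last interval is truncated at the final time T. Assuming the chain is ergodic (irreducible on the bounded state space), this time average converges to the stationary distribution π(x) as T → ∞:

$$
\hat{\pi}_T(x) = \frac{1}{T}\int_0^T \mathbf{1}\{X(t) = x\}\,dt = \frac{1}{T}\sum_{i} \tau_i\,\mathbf{1}\{X_i = x\}
$$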
However, this doesn’t seem to be possible with JumpProcesses.jl. Whenever there are variable jump rates, it forces every state to be saved, even though only the previous state needs to be kept for the next time step.
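The following is a minimal hand-written Direct-method (Gillespie) sketch for the two-state model X1 <--> X2 used below. It illustrates that only the current state has to be kept in memory. It does not use JumpProcesses.jl at all. The function name `ssa_time_in_states` is hypothetical, and the sketch assumes mass-action propensities `k1*X1` (X1 → X2) and `k2*X2` (X2 → X1) with `k1, k2 > 0`. Because `X1 + X2 = N` is conserved, occupancy is indexed by `x1` alone.

```julia
using Random

# Hypothetical helper: Direct-method SSA for X1 <--> X2 that keeps only the
# current state and accumulates the holding time spent at each value of x1.
function ssa_time_in_states(N, k1, k2, T; rng = Random.default_rng())
    occupancy = zeros(N + 1)   # occupancy[x1 + 1] = total time with X1 == x1
    x1, x2 = N, 0              # initial state (all molecules in X1)
    t = 0.0
    while true
        a1 = k1 * x1           # propensity of X1 -> X2
        a2 = k2 * x2           # propensity of X2 -> X1
        a0 = a1 + a2
        τ = randexp(rng) / a0  # exponential holding time in the current state
        if t + τ >= T
            occupancy[x1 + 1] += T - t  # truncate the final interval at T
            break
        end
        occupancy[x1 + 1] += τ # time belongs to the state *before* the jump
        t += τ
        if rand(rng) * a0 < a1 # choose which reaction fires
            x1 -= 1; x2 += 1
        else
            x1 += 1; x2 -= 1
        end
    end
    return occupancy ./ T      # normalized time-average occupancy
end
```

For example, `ssa_time_in_states(50, 1.0, 2.0, 1e3)` returns a length-51 vector that sums to 1 (up to rounding). Entry `x1 + 1` estimates P(X1 = x1). Memory use is O(N) regardless of how long the simulation runs.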
Below is a simple example that uses a callback to track the cumulative time in each state, but the solver records the full trajectory anyway.
Is there something I’m missing, or could JumpProcesses.jl be extended to store only the previous state?
using Catalyst
using DifferentialEquations
two_state_model = @reaction_network begin
    (k1, k2), X1 <--> X2
end
p = (:k1 => 1.0, :k2 => 2.0)
N = 50
u₀ = [:X1 => N, :X2 => 0]
tspan = (0.0, 1e3)
# track cumulative time spent in each possible state (x1, x2);
# when normalized this gives the steady-state distribution
accumulated_time = zeros(N + 1, N + 1)
last_t = 0.0
last_u = [N, 0]  # state occupied since last_t, starting from u₀
function aggregate_state!(integrator)
    # integrator.u is the state *after* the jump, so the elapsed interval
    # [last_t, t] belongs to the previously recorded state last_u
    x1, x2 = last_u
    dt = integrator.t - last_t
    x1_max, x2_max = size(accumulated_time)
    if x1 < x1_max && x2 < x2_max
        accumulated_time[x1+1, x2+1] += dt
    end
    global last_t = integrator.t
    last_u .= integrator.u  # mutates the global vector in place
end
# update aggregated values each time step
cb = DiscreteCallback((u, t, integrator) -> true, aggregate_state!)
prob = DiscreteProblem(two_state_model, u₀, tspan, p)
# setting save_positions does nothing because the problem has VariableRateJumps
jump_prob = JumpProblem(two_state_model, prob, Direct(), save_positions=(false, false))
# not able to set save_everystep=false as it's set manually by JumpProblem
sol = solve(jump_prob, SSAStepper(); callback=cb, save_end=true)
# credit the interval between the last jump and the end of the simulation
accumulated_time[last_u[1]+1, last_u[2]+1] += sol.t[end] - last_t
ss = accumulated_time ./ sum(accumulated_time)
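For this particular model, the simulated estimate can be checked against a closed form. X1 <--> X2 is a closed system of first-order conversions, so each of the N molecules switches independently. At stationarity this gives X1 ~ Binomial(N, k2/(k1 + k2)), with X2 = N − X1. The sketch below computes that distribution in log space to avoid overflow in `binomial` for large N. The function name `exact_x1_distribution` is hypothetical, and it assumes `k1, k2 > 0`.

```julia
# Hypothetical helper: exact stationary distribution of X1 for X1 <--> X2
# with rates k1 (X1 -> X2) and k2 (X2 -> X1); returns P(X1 = x) for x = 0:N.
function exact_x1_distribution(N, k1, k2)
    q = k2 / (k1 + k2)          # probability that a single molecule is in X1
    logp = zeros(N + 1)
    logp[1] = N * log1p(-q)     # log P(X1 = 0) = N * log(1 - q)
    for x in 0:N-1
        # ratio P(x + 1) / P(x) = (N - x) / (x + 1) * q / (1 - q)
        logp[x + 2] = logp[x + 1] + log((N - x) / (x + 1)) + log(q) - log1p(-q)
    end
    return exp.(logp)
end
```

For example, `p_exact = exact_x1_distribution(N, 1.0, 2.0)` can be compared with the nonzero entries of the simulated matrix, `p_sim = [ss[x+1, N-x+1] for x in 0:N]`.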