JumpProblem with large number of jumps crashes

Thank you very much for your help. The code no longer crashes, and with your suggestions it runs roughly 27x faster than the old implementation (182.4 ms versus 6.8 ms for 1,000 nodes), a considerable improvement.

I implemented your suggestions as shown below, using the updated API.

using LightGraphs
using DifferentialEquations
using Plots
using Distributions
using SparseArrays
using LinearAlgebra
using BenchmarkTools

@info "Initializing system..."
degree = 5
kdist = Poisson(degree)
N = Int(1000)
ks = 1
while sum(ks) % 2 != 0
    global ks = rand(kdist, N)
end

g = random_configuration_model(N, ks)

i₀ = rand(1:N, round(Int, N*0.1)) 
x₀ = [ones(N); zeros(N); zeros(N)]
x₀[i₀] .= 0
x₀[i₀ .+ N] .= 1
tspan = (0.0, 100.0)
p = [0.25, 0.05, degree, adjacency_matrix(g)]

function agg_sol(sol, N)
    agg = reshape(sol[:, :], (N, 3, :))
    agg = mean(agg; dims=1)
    agg = dropdims(agg; dims=1)
    agg = permutedims(agg, (2, 1))
    return agg
end

@info "Collecting jumps..."
function dNᵢ(i)

    @views function rate(u, p, t)
        @inbounds u[i]*(p[1]/p[3])*dot(p[4][:, i], u[(N+1):2N])
    end

    function affect!(integrator)
        integrator.u[i] -= 1
        integrator.u[N+i] += 1
    end

    return rate, affect!

end

infections = ConstantRateJump[]
for i in 1:N
    push!(infections, ConstantRateJump(dNᵢ(i)...))
end

recoveries_reactant = Vector{Pair{Int, Int}}[]
recoveries_net = Vector{Pair{Int, Int}}[]
for i in 1:N
    push!(recoveries_reactant, [N+i => 1])
    push!(recoveries_net, [N+i => -1, 2N+i => +1])
end
recoveries = MassActionJump(repeat([p[2]], N), recoveries_reactant, recoveries_net; scale_rates=false)

@info "Building dependency graph..."
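# Jump numbering used here: recoveries (mass action) are jumps 1:N,
# infections (constant rate) are jumps N+1:2N.
# vtoj[v] lists the jumps whose rates depend on variable v.
vtoj = Vector{Vector{Int64}}(undef, 3N)
for i in 1:N
    @inbounds vtoj[i] = [N+i]                              # S_i: infection of node i
    @inbounds vtoj[N+i] = [N .+ findnz(p[4][:, i])[1]; i]  # I_i: neighbors' infections, recovery of i
    @inbounds vtoj[2N+i] = Int[]                           # R_i: no rate depends on it
end

# jtov[j] lists the variables changed when jump j fires.
jtov = Vector{Vector{Int64}}(undef, 2N)
@views for i in 1:N
    @inbounds jtov[i] = [N+i, 2N+i]  # recovery of i changes I_i and R_i
    @inbounds jtov[N+i] = [i, N+i]   # infection of i changes S_i and I_i
end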

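# jtoj[j] lists the jumps whose rates must be recomputed after jump j fires.
jtoj = Vector{Vector{Int64}}(undef, 2N)
for i in 1:N
    # Recovery of i changes I_i and R_i: update neighbors' infections and recovery i
    # (infection i is listed as well, which is harmless).
    jtoj[i] = [N .+ findnz(p[4][i, :])[1]; N + i; i]
    # Infection of i changes S_i and I_i: update neighbors' infections,
    # infection i, and recovery i (its rate depends on I_i).
    jtoj[N+i] = [N .+ findnz(p[4][i, :])[1]; N + i; i]
end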

@info "Building discrete problem..."
discrete_prob = DiscreteProblem(x₀, tspan, p)

@info "Building jump problem..."
# see https://github.com/SciML/ModelingToolkit.jl/blob/f12f472f630fd85a6fab4ca547b1c679217c33df/src/systems/jumps/jumpsystem.jl
# see https://diffeq.sciml.ai/stable/types/jump_types/
js = JumpSet(; constant_jumps=infections, massaction_jumps=recoveries)
# jump_prob = JumpProblem(discrete_prob, RSSA(), js; vartojumps_map=vtoj, jumptovars_map=jtov)
jump_prob = JumpProblem(discrete_prob, DirectCR(), js; dep_graph=jtoj)

@info "Solving problem..."
@btime solve(jump_prob, SSAStepper())
jump_sol = solve(jump_prob, SSAStepper())

@info "Plotting..."
plot(agg_sol(jump_sol, N));
ylims!(0, 1);
plot!(legend=:right);
savefig("test5.png")

@info "Done."

With the previous implementation (without MassActionJump and the views), I got the following performance with 1,000 nodes:

182.444 ms (201195 allocations: 172.55 MiB)

With the new implementation, I got the following performance:

6.797 ms (92786 allocations: 24.99 MiB)

I can even run 10,000 nodes in under 1 second with 2.2 GB of allocations. (With 100,000 nodes the program crashes, likely because I don't have enough memory.)

Regarding performance, I was not sure whether @views and @inbounds are redundant or whether I need both. Could you tell me which is correct?
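
To make the question concrete, here is a minimal sketch of how I currently understand the two macros. As far as I can tell, @views controls whether a slice is copied, while @inbounds only skips bounds checks, so they would address different costs. The function dot_column and the toy data are made up for illustration and are not part of my model.

```julia
A = [1.0 2.0; 3.0 4.0]
u = [10.0, 20.0]

# Without @views, slicing allocates a fresh copy of the column.
col_copy = A[:, 1]
# With @views, the same syntax yields a SubArray that aliases A (no copy).
col_view = @views A[:, 1]

A[1, 1] = 100.0
copy_sees_change = col_copy[1] == 100.0  # the copy is detached from A
view_sees_change = col_view[1] == 100.0  # the view reflects the update

# @inbounds does not affect copying: it only lets the compiler skip
# bounds checks on the indexing inside the annotated expression.
# It is only safe here if A has at least length(u) rows and i columns.
function dot_column(A, u, i)
    s = 0.0
    @inbounds for k in eachindex(u)
        s += A[k, i] * u[k]
    end
    return s
end

result = dot_column(A, u, 1)
```

If that reading is right, the rate function benefits from @views because of the p[4][:, i] and u[(N+1):2N] slices, and @inbounds is a separate, smaller optimization on top.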

Do you see any other low-hanging fruit for improving performance?

Finally, I would like to share that my pain point was figuring out the order of the jumps so that I could build the dependency graphs. It is not documented anywhere in which order the mass action and constant rate jumps are combined in the JumpSet. The structure itself also has no method that returns the mapping between jumps and their indices.
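
For reference, this is the numbering my dependency graphs rely on, written out as a small sketch. The helper names are hypothetical (they are not part of DifferentialEquations.jl), and the ordering itself is an assumption: the N mass action jumps first, followed by the N constant rate jumps.

```julia
# Hypothetical helpers, not library functions.
# Assumption: recoveries (mass action) are jumps 1:N and
# infections (constant rate) are jumps N+1:2N.
recovery_jump_index(i) = i
infection_jump_index(i, N) = N + i

# State layout of the model: S = 1:N, I = N+1:2N, R = 2N+1:3N.
susceptible_index(i) = i
infected_index(i, N) = N + i
recovered_index(i, N) = 2N + i

# Variables changed when each jump fires, for a small N.
N = 3
jtov = Vector{Vector{Int}}(undef, 2N)
for i in 1:N
    jtov[recovery_jump_index(i)] = [infected_index(i, N), recovered_index(i, N)]
    jtov[infection_jump_index(i, N)] = [susceptible_index(i), infected_index(i, N)]
end
```

Naming the indices this way at least keeps the assumed ordering in one place, so it only has to be changed once if the assumption turns out to be wrong.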

The way I got around this was to use Catalyst.jl to build an equivalent model and use the functions asgraph, variable_dependencies and eqeq_dependencies to piece those mappings together. (I am happy to share the Catalyst.jl code for reference.) It would be great to have an easier way to do this, or perhaps to improve the documentation with a discussion of this.
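
A dependency-free sketch of the same idea: once the jump-to-variable and variable-to-jump maps are known, the jump-to-jump graph can be derived by composing them. The function compose_dependencies is a hypothetical helper, not a library function, and the example assumes the same numbering as my script (recoveries as jumps 1:N, infections as jumps N+1:2N) on a toy path graph 1 - 2 - 3.

```julia
# Hypothetical helper: compose jump -> changed variables with
# variable -> dependent jumps to obtain jump -> dependent jumps.
function compose_dependencies(jtov, vtoj)
    jtoj = Vector{Vector{Int}}(undef, length(jtov))
    for j in eachindex(jtov)
        deps = Set{Int}([j])          # a jump's own rate is always recomputed
        for v in jtov[j]
            union!(deps, vtoj[v])
        end
        jtoj[j] = sort!(collect(deps))
    end
    return jtoj
end

# Toy network with N = 3 nodes on a path 1 - 2 - 3.
N = 3
nbrs = [[2], [1, 3], [2]]

vtoj = Vector{Vector{Int}}(undef, 3N)
jtov = Vector{Vector{Int}}(undef, 2N)
for i in 1:N
    vtoj[i] = [N + i]                 # S_i enters the infection rate of i
    vtoj[N + i] = [N .+ nbrs[i]; i]   # I_i enters neighbors' infections and recovery i
    vtoj[2N + i] = Int[]              # R_i enters no rate
    jtov[i] = [N + i, 2N + i]         # recovery i changes I_i and R_i
    jtov[N + i] = [i, N + i]          # infection i changes S_i and I_i
end

jtoj = compose_dependencies(jtov, vtoj)
```

Building jtoj this way avoids writing the third graph by hand, so any mistake in the numbering only has to be fixed in vtoj and jtov.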