# LightGraphs + ModelingToolkit optimization

Hi,

I finally got an SIR model to work with LightGraphs combined with ModelingToolkit. Here is the code:

```julia
using Plots
using DataFrames
using ModelingToolkit
using DiffEqBase
using DiffEqJump
using LightGraphs
using SpecialGraphs

const nedgess = 3000
const nvertss = 500

const dgs = DiGraph(nvertss, nedgess);
# Second argument (5) is the degree of each node
const dgr = random_regular_digraph(nvertss, 5)

function setupSystem(graph)
    nverts = length(vertices(graph))
    nedges = length(edges(graph))
    # Allow for a different beta on each edge
    @parameters t β[1:nedges]  γ[1:nverts] ;
    @variables S[1:nverts](t);
    @variables I[1:nverts](t);
    @variables R[1:nverts](t);

    # Infection: one reaction per edge, I[src] + S[dst] --> 2 I[dst]
    rxsS = [Reaction(β[i],[I[src(e)], S[dst(e)]], [I[dst(e)]], [1,1], [2])
        for (i,e) ∈ enumerate(edges(graph))]

    # Recovery: one reaction per node
    rxsI = [Reaction(γ[v],[I[v]], [R[v]])  # note: src to src, yet there are no edges
        for v ∈ vertices(graph)]

    rxs = vcat(rxsS, rxsI);
    vars = vcat(S,I,R);
    params = vcat(β,γ);

    rs = ReactionSystem(rxs, t, vars, params);
    js = convert(JumpSystem, rs);
    println("Completed: convert(JumpSystem)")
    S0 = ones(nverts)
    I0 = zeros(nverts)
    R0 = zeros(nverts)

    S0[1] = 0.; # One person is infected
    I0[1] = 1.;
    R0[1] = 1. - S0[1] - I0[1]
    vars0 = vcat(S0, I0, R0);

    # Two column vectors
    γ = fill(0.25, nverts);
    β = fill(0.50, nedges);
    params = vcat(β,γ)

    initial_state = [convert(Variable,state) => vars0[i] for (i,state) in enumerate(states(js))];
    initial_params = [convert(Variable,par) => params[i] for (i,par) in enumerate(parameters(js))];

    tspan = (0.0,20.0)
    @time dprob = DiscreteProblem(js, initial_state, tspan, initial_params)
    println("Completed: DiscreteProblem")
    @time jprob = JumpProblem(js, dprob, NRM())
    println("Completed: JumpProblem")
    @time sol = solve(jprob, SSAStepper())
    println("Completed: solve")

    return sol
end

function processData(sol)
    nverts = nvertss
    nedges = nedgess
    println("nverts=$nverts, nedges=$nedges")

    dfs = convert(Matrix, DataFrame(sol))
    Sf = dfs[1:nverts,:]
    If = dfs[nverts+1:2*nverts,:]
    Rf = dfs[2*nverts+1:3*nverts,:]
    # Average each compartment over all nodes at every saved time
    Savg = (sum(Sf; dims=1)') / nverts
    Iavg = (sum(If; dims=1)') / nverts
    Ravg = (sum(Rf; dims=1)') / nverts
    print(Savg)
    return Savg, Iavg, Ravg
end

# Times: sol.t
# Solution at nodes: sol.u
# sol.u[1] |> length == 3 * nverts (1500 for this graph)

sol = setupSystem(dgr)
Savg, Iavg, Ravg = processData(sol)

plot(sol.t, Savg)
plot(sol.t, Iavg)
plot(sol.t, Ravg)
```


A single run with tspan=(0.,20.) takes about 18.85 sec, almost all of it in JumpProblem (18.55 sec). The network has 500 nodes with 5 edges per node, and the solution had only 955 discrete times. I do not have a good sense of how well my code is optimized, although I do not see how poor optimization on my part could slow down the JumpProblem routine.

If anybody has any insight, I would be greatly appreciative. I will try to copy some code from Simon Frost and see if I can duplicate the results. Simon did not use ModelingToolkit.

We can definitely look into why constructing the JumpProblem is taking a while.

Out of curiosity though, is the time to build the JumpProblem that important? Usually with SSAs I’m calling solve thousands of times to collect statistics, so that is the dominant cost. Will that not be your ultimate use case? How is solve performing?

Edit: Actually I see that you gave the total time too and solve is minimal. So that is good since it should usually be the dominant time when running many samples.

That’s a performance bug on our end: we should get it fixed up. https://github.com/SciML/ModelingToolkit.jl/issues/426

You beat me to replying while I was writing the issue.

@isaacsas: to answer your question, yes, I could run solve many times, and solve is very efficient: 18.7 sec for JumpProblem, but only 0.003 sec for solve. Note, though: what happens if the graph has 5000 nodes instead of 500, which is still a small graph? I also cut down the cost by calling random_regular_digraph and specifying a fixed number of neighbors for each node, namely 5. It goes without saying that I am learning and developing intuition.

I do have another question. I noticed that the call to neighbors has an allocation of 16 bytes. (I was experimenting with the library EpidemicSimulations.jl, which uses LightGraphs.jl and was written by Simon Frost. I wanted to know the speeds I could expect on the same graphs.) For large graphs, the function neighbors() is called many times with a monotonically increasing number of total allocations. It seems to me that this could be reduced to zero allocations using a preallocated array. I checked the source code but could not figure it out.

Here is the function update! in sir.jl in EpidemicSimulations.jl (I added the println and @time statements):

```julia
function update!(m::SIR, node::Int, n_step::Int)

    println("\nUpdate! top: ")
    @time if node_state(m, :infectious, node, n_step - 1)
        if rand() <= rate_recovered(m)
            m.states[:infectious][node, n_step] = 0
            m.states[:recovered][node, n_step] = 1
        end
    end

    println("\nNeighbors")
    @time a = neighbors(m.G, node)

    println("\nUpdate! bottom")
    @time if node_state(m, :susceptible, node, n_step - 1)
        for neighbour in neighbors(m.G, node)
            if node_state(m, :infectious, neighbour, n_step - 1)
                if rand() <= rate_infectious(m)
                    m.states[:susceptible][node, n_step] = 0
                    m.states[:infectious][node, n_step] = 1
                    break
                end
            end
        end
    end

    nothing
end
```


I agree, the JumpProblem timing needs to get fixed. But we can definitely do that, so that calling solve many times should (hopefully) be the dominant cost on your real networks.

I’m not sure about the neighbors issue; I haven’t used LightGraphs much myself. It might be better to ask about that in a separate issue aimed at people familiar with that library.

I like your example. If you could report back your timings for the different libraries at some point, it would be interesting to know. I’d also suggest playing with some of our other SSAs (RSSA, DirectCR and RSSACR) to see if any of them are faster than NRM (once we get this bug fixed and you are working on your desired problem size).

Finally, one last comment: you should make your initial conditions integers, like S0[1]=0, if you mean to use the jump process that counts the number of each species.

> neighbors() is called many times with a monotonically increasing number of total allocations. It seems to me that this could be reduced to zero allocations using a preallocated array.

If you wanted to do this, it would require you to preallocate Δ(g) elements. (Just putting that out there if you were writing code.)
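For what it's worth, here is a minimal sketch of that preallocation idea. It does not use the LightGraphs API: `adj` is a hypothetical stand-in for a graph stored as plain adjacency lists, `maxdeg` plays the role of Δ(g), and `neighbors!` is a made-up helper name.

```julia
# Hypothetical stand-in for a graph: adjacency lists as Vector{Vector{Int}}.
adj = [[2, 3], [1], [1, 2, 4], Int[]]

# Maximum degree over all nodes, the role Δ(g) plays for a real graph.
maxdeg = maximum(length, adj)

# One buffer, allocated once and reused for every node.
buf = Vector{Int}(undef, maxdeg)

# Copy the neighbors of `v` into `buf` and return how many were written.
function neighbors!(buf::Vector{Int}, adj::Vector{Vector{Int}}, v::Int)
    nbrs = adj[v]
    n = length(nbrs)
    copyto!(buf, 1, nbrs, 1, n)
    return n
end

n = neighbors!(buf, adj, 3)
nbrs3 = view(buf, 1:n)  # neighbors of node 3, read from the shared buffer
```

Only the first `n` entries of `buf` are valid after each call, so the caller has to carry `n` around (or take a view, as in the last line).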

Would probably be nice to turn this into a benchmark.

True. When we get all the updates settled down we can add it with the BCR model.

I love it. That will help me learn as well.

In the meantime, I have some questions on the SSA jump algorithm. I ran solve several times over the time span tspan=(0,20). Sometimes sol.t has over 900 elements, and sometimes it has only 3 or 4. That is very hard to believe. My mean time of infection is 1/β = 2 and the mean time to recovery is 1/γ = 4. The first time I run solve, after running JumpProblem, the time array always has over 900 elements, but the additional solves have far fewer timestamps. Perhaps somebody could explain this behavior? Thanks.

I agree! Note that even if there are 200,000 nodes (more than I need), I would need 16 bytes per node, or less than 5 Mbytes, quite acceptable.

What is the meaning of the Δ(g) notation? Surely it does not mean the Laplacian.

If the first event is a recovery, then whoops, it’s a short simulation. That might be what’s going on? I’d have to see a plot of those cases, but that’s my guess given what I’ve seen with SIR models.

You only have one infected initially, right? Are you sure it isn’t just that sometimes the first thing to happen is that the one infected becomes recovered, and so nothing else can happen?
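A rough back-of-the-envelope check of that guess, under assumptions that are mine rather than verified against the model: the single infected node has k = 5 outgoing edges to susceptible neighbors, each edge fires at rate β = 0.5, and recovery fires at rate γ = 0.25. The helper name `first_event_is_recovery` is made up for illustration.

```julia
# Probability that the very first event is the recovery of the only infected
# node, when recovery (rate γ) competes with k infection channels (rate β each).
first_event_is_recovery(β, γ, k) = γ / (γ + k * β)

# Values from the thread: β = 0.50, γ = 0.25, and an assumed out-degree of 5.
p = first_event_is_recovery(0.50, 0.25, 5)   # 0.25 / 2.75 = 1/11
```

Under those assumptions, roughly one run in eleven would stop right after the first event, which would show up as a sol.t with only a handful of entries.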

Now you can tell I am a beginner. I will test with 5 initial infected and see what happens. The problem disappears. I created a little function:

```julia
function infect(i)
    S0[i] = 0; # Node i is infected
    I0[i] = 1;
    R0[i] = 1 - S0[i] - I0[i]
end

infect.([1,10,15,25,45])
```


Yes, I could have chosen 5 nodes at random. I could not do this off the top of my head. Thank you both!

I do have a question. When you run solve many times, how do you average over the runs? Of course, averaging at the last time, tspan[2], is trivial. But what about intermediate times? Does DifferentialEquations.jl allow interpolation when doing jump calculations? I will go check.

It does piecewise constant interpolation, so you get the exact value at the time.
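To illustrate what piecewise constant interpolation means here, a small sketch in plain Julia (not the library's internals): the value at time t is the state saved at the last jump time that is not after t. The helper `value_at` and the sample path are made up for illustration.

```julia
# Value of a piecewise-constant path at time `t`: the state saved at the
# last jump time that is <= t.
function value_at(ts::Vector{Float64}, us::Vector{Int}, t::Real)
    i = searchsortedlast(ts, t)
    i == 0 && throw(ArgumentError("t lies before the first saved time"))
    return us[i]
end

# Hypothetical path: jump times and the state right after each jump.
ts = [0.0, 0.7, 1.9, 4.2]
us = [1, 2, 3, 2]

# Sample the path on a regular grid shared by all runs.
grid = 0.0:1.0:4.0
samples = [value_at(ts, us, t) for t in grid]  # [1, 2, 3, 3, 3]
```

Once every run is sampled on the same grid, averaging over runs at intermediate times is an elementwise mean of the `samples` vectors.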

You can also look at the documentation on EnsembleProblems, which are a way to have it run many simulations (even in parallel).

You might also want to use the keyword argument save_positions=(false,false) within the JumpProblem call, and use the saveat parameter within the solve call to set how often to save. This can reduce memory use quite a bit on longer simulation runs. (The former turns off saving the state on every jump event, which is needed for the exact path reconstruction with piecewise constant interpolation.)

Excellent suggestions! Thank you. I read about these two items in the past, but cannot keep everything in memory!

Δ(g) will hopefully be much less than nv(g). In any case, remember to squash(g) to make sure you’re using the smallest possible datatype.

Yes, EnsembleProblem is something to look into. It does look like all the solutions are stored in memory,
which is problematic for large problems (I am not saying my problems are large!)
I am interested in examining the source code, because I tried to compute my own standard deviation,
incrementally, and could not do it. The format of sol confuses me. EnsembleProblem will avoid these issues.

Here is what I tried:


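```julia
tspan = (0.0,20.0)
t_dense = 0.:.5:tspan[2]
# Assumes setupSystem has been changed to return jprob instead of sol
@time jprob = setupSystem(dgr);
# Stack the sampled states into a (states x times) matrix so that .* is elementwise
@time solm = reduce(hcat, solve(jprob, SSAStepper())(t_dense).u);
@time solsq = solm .* solm;

nruns = 30;
for i in 1:nruns
    sol0 = reduce(hcat, solve(jprob, SSAStepper())(t_dense).u)
    global solm = solm + sol0
    global solsq = solsq .+ sol0 .* sol0
end
# The initial run plus nruns more gives nruns + 1 samples
solmean = solm ./ (nruns + 1)
solvariance = ((solsq ./ (nruns + 1)) .- (solmean .* solmean))

# solmean .^ 2 could also be used (in principle)
```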
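As an alternative to accumulating sums of squares, here is a hedged sketch of an incremental mean and variance using Welford's update, which is a different technique from the one above and tends to be better behaved numerically. Everything in it is illustrative: `RunningStats`, `update!`, `variance`, and `fake_run` are made-up names, and `fake_run` stands in for one simulation sampled on a time grid. Note that it returns the sample variance (dividing by n - 1), not the population variance.

```julia
using Random

# Running mean and variance over runs, updated one run at a time.
mutable struct RunningStats
    n::Int
    mean::Matrix{Float64}
    m2::Matrix{Float64}   # running sum of squared deviations from the mean
end

RunningStats(nrows::Int, ncols::Int) =
    RunningStats(0, zeros(nrows, ncols), zeros(nrows, ncols))

# Welford's update: fold one (states x times) matrix into the statistics.
function update!(s::RunningStats, x::AbstractMatrix)
    s.n += 1
    delta = x .- s.mean
    s.mean .+= delta ./ s.n
    s.m2 .+= delta .* (x .- s.mean)
    return s
end

# Sample variance; needs at least two runs.
variance(s::RunningStats) = s.m2 ./ (s.n - 1)

# Hypothetical stand-in for one simulation sampled on a time grid.
fake_run(rng, nstates, ntimes) = rand(rng, 0:1, nstates, ntimes)

rng = MersenneTwister(1)
stats = RunningStats(3, 5)
for _ in 1:30
    update!(stats, fake_run(rng, 3, 5))
end
```

Only the current run plus two matrices are kept in memory, so none of the individual solutions need to be stored.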