How to vectorize a SIR Model to add age-stratification with the DifferentialEquations.jl package?

Consider the following simple SIR model I’ve written in the DifferentialEquations.jl framework:

# Import 
using DifferentialEquations

# Parameter
𝒫 = [0.1,0.2]          # Model Parameters
ℬ = [0.8,0.1, 0.0]  # Initial condition
𝒯 = (0.0,365.0)     # Time 

# Model
function Φ!(du,u,p,t)
    S,I,R = u
    β,γ = p
    
    du[1] = -β*S*I
    du[2] = β*S*I - γ*I
    du[3] = γ*I
end

# Problem Definition
problem = ODEProblem(Φ!, ℬ, 𝒯, 𝒫)

# Problem Solution
solution = solve(problem);

I would like to extend it with age-stratification in a way that’s both scalable (at least up to 20 age groups per compartment) and compact (vector form).

In the following code chunk I try to define an uncoupled age-specific SIR model with only 3 age groups:

# Parameter
𝒫 = [0.14,0.12,0.2,0.2]          # Model Parameters
ℬ = [0.9,0.9,0.9,0.1,0.1,0.1,0.0,0.0,0.0]  # Initial condition
𝒯 = (0.0,365.0)     # Time 

# Model
function Φ!(du,u,p,t)
    s1,s2,s3,i1,i2,i3,r1,r2,r3 = u 
    β1,β2,β3,γ = p
    
    du[1] = -β1*s1*i1
    du[2] = -β2*s2*i2
    du[3] = -β3*s3*i3
    
    du[4] = β1*s1*i1-γ*i1
    du[5] = β2*s2*i2-γ*i2
    du[6] = β3*s3*i3-γ*i3

    du[7] = γ*i1
    du[8] = γ*i2
    du[9] = γ*i3
end

# Problem Definition
problem = ODEProblem(Φ!, ℬ, 𝒯, 𝒫)

# Problem Solution
solution = solve(problem);
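Written out, the uncoupled model above is the same three equations repeated for each age group a = 1, 2, 3, with an age-specific transmission rate β_a and a shared recovery rate γ. This is a restatement of the code, not a new model:

```latex
\begin{aligned}
\frac{dS_a}{dt} &= -\beta_a S_a I_a \\
\frac{dI_a}{dt} &= \beta_a S_a I_a - \gamma I_a \\
\frac{dR_a}{dt} &= \gamma I_a
\end{aligned}
\qquad a = 1, 2, 3
```

Vectorizing therefore amounts to treating S, I, R and β as length-3 (or length-20) vectors and applying these equations element-wise.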

This works, but it’s neither compact nor scalable. How can I vectorize it so that it is both?


You could use the Particles type from MonteCarloMeasurements.jl, which is designed for vectorized operations. For your problem it would look like this, with the different model parameters given to Particles as vectors:

using OrdinaryDiffEq, MonteCarloMeasurements, Plots  # Plots is needed to display mcplot

# Parameter
𝒫 = [Particles([0.14, 0.12]), Particles([0.2, 0.2])]          # Model Parameters
ℬ = eltype(𝒫).([0.9, 0.1, 0])  # Initial condition
𝒯 = (0.0,365.0)     # Time 

# Model
function Φ!(du,u,p,t)
    S,I,R = u
    β,γ = p
    
    du[1] = -β*S*I
    du[2] = β*S*I - γ*I
    du[3] = γ*I
end

# Problem Definition
problem = ODEProblem(Φ!, ℬ, 𝒯, 𝒫)

# Problem Solution
solution = solve(problem, Tsit5());
mcplot(solution.t, Matrix(Array(solution)'))

To make sure the initial conditions always sum to one when they are given as distributions, it’s convenient to calculate one of them from the others, e.g.,

S = Particles(...)
I = Particles(...)
R = 1 - S - I
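For the age-stratified case the same trick can be applied per age group. The sketch below builds a matrix initial condition with one row per age group and columns S, I, R, deriving R so that each row sums to one. The helper name age_initial_condition and the numbers are hypothetical, and plain Float64 vectors stand in for Particles:

```julia
# Hypothetical helper: build an (age groups × 3) initial-condition matrix
# with columns S, I, R, where R is derived so that each row sums to one.
function age_initial_condition(S::AbstractVector, I::AbstractVector)
    length(S) == length(I) ||
        throw(DimensionMismatch("S and I must have the same length"))
    R = 1 .- S .- I          # remaining fraction in each age group
    return hcat(S, I, R)     # column 1 = S, column 2 = I, column 3 = R
end

# Three age groups with different initial susceptible/infected fractions.
u0 = age_initial_condition([0.9, 0.8, 0.7], [0.05, 0.1, 0.1])
```

Computing the last compartment from the others keeps each row consistent by construction, up to floating-point rounding.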

To simulate uncertainty in the initial conditions, you could do something like:

𝒫 = [0.1 0.2]          # Model Parameters
S = 0.9 ± 0.005        # Initial condition
R = 0
I = 1 - S - R
ℬ = [S, I, R]
𝒯 = (0.0,365.0)     # Time 

where the ± operator (typed as \pm followed by Tab) creates Particles holding 2000 normally distributed samples.


Just make your initial condition a matrix, with one row per age group and one column per compartment (S, I, R), and then:

function Φ!(du,u,p,t)
    S = @view u[:,1]
    I = @view u[:,2]
    R = @view u[:,3]
    dS = @view du[:,1]
    dI = @view du[:,2]
    dR = @view du[:,3]
    β = @view p[1:3]
    γ = p[4]
    
    @. dS = -β*S*I
    @. dI = β*S*I-γ*I
    @. dR = γ*I
end
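To check that this matrix form reproduces the nine-equation model from the question, one can evaluate both right-hand sides at the same state and compare them. The sketch below uses only Base Julia; Φ9! and Φmat! are hypothetical names for the question's hand-written function and the matrix version above, and the state values are made up for the test:

```julia
# Nine-equation version from the question (state as a flat vector).
function Φ9!(du, u, p, t)
    s1, s2, s3, i1, i2, i3, r1, r2, r3 = u
    β1, β2, β3, γ = p
    du[1] = -β1*s1*i1
    du[2] = -β2*s2*i2
    du[3] = -β3*s3*i3
    du[4] = β1*s1*i1 - γ*i1
    du[5] = β2*s2*i2 - γ*i2
    du[6] = β3*s3*i3 - γ*i3
    du[7] = γ*i1
    du[8] = γ*i2
    du[9] = γ*i3
    return nothing
end

# Matrix version: rows are age groups, columns are S, I, R.
function Φmat!(du, u, p, t)
    S  = @view u[:, 1]
    I  = @view u[:, 2]
    dS = @view du[:, 1]
    dI = @view du[:, 2]
    dR = @view du[:, 3]
    β  = @view p[1:3]
    γ  = p[4]
    @. dS = -β * S * I
    @. dI = β * S * I - γ * I
    @. dR = γ * I
    return nothing
end

p = [0.14, 0.12, 0.2, 0.2]
u_mat = [0.9 0.1  0.0;
         0.8 0.15 0.05;
         0.7 0.2  0.1]     # 3 age groups × (S, I, R)
u_vec = vec(u_mat)         # column-major: [s1,s2,s3,i1,i2,i3,r1,r2,r3]

du_mat = similar(u_mat)
du_vec = similar(u_vec)
Φmat!(du_mat, u_mat, p, 0.0)
Φ9!(du_vec, u_vec, p, 0.0)
```

The same u_mat and p can be passed to ODEProblem unchanged; the solver only needs an array it can broadcast over.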

You can then use LabelledArrays.jl to make this style even simpler:

https://github.com/SciML/LabelledArrays.jl


I was considering the same thing as Chris.
https://diffeq.sciml.ai/stable/tutorials/ode_example/#ode_other_types

The solutions are equivalent once the layout of the initial conditions matches the views Chris proposes, e.g. by setting the initial conditions as:

ℬ = [0.9 0.9 0.9; 0.1 0.1 0.1; 0.0 0.0 0.0]'  # Initial condition
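The transpose is what makes the layouts line up: the literal is written with rows S, I, R, and Julia stores matrices in column-major order, so once columns are S, I, R the matrix flattens to exactly the ordering [s1, s2, s3, i1, i2, i3, r1, r2, r3] used in the question. A quick check in Base Julia; permutedims is used here instead of ' only to get a plain Matrix rather than an Adjoint wrapper:

```julia
# Initial condition from the question, as a flat vector of 9 states.
u_flat = [0.9, 0.9, 0.9, 0.1, 0.1, 0.1, 0.0, 0.0, 0.0]

# Rows written as S, I, R, then transposed so that columns are S, I, R
# and rows are age groups (matching u[:, 1] = S, u[:, 2] = I, u[:, 3] = R).
ℬ = permutedims([0.9 0.9 0.9; 0.1 0.1 0.1; 0.0 0.0 0.0])

# Column-major flattening recovers the question's ordering,
# and reshape goes the other way (sharing memory with u_flat).
flat_again = vec(ℬ)
ℬ2 = reshape(u_flat, 3, 3)
```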

For scalability (and perhaps compactness as well), one could store the parameters as

𝒫 = [[0.14,0.12,0.2],0.2]          # Model Parameters

and avoid hard-coding the number of age groups in Φ!:

function Φ!(du,u,p,t)
    S = @view u[:,1]
    I = @view u[:,2]
    dS = @view du[:,1]
    dI = @view du[:,2]
    dR = @view du[:,3]
    β, γ = p

    # write through the views instead of re-indexing du
    @. dS = -β*S*I
    @. dI = β*S*I - γ*I
    @. dR = γ*I
end

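This has a significant performance cost when calling solve, most likely because the 𝒫 above is a Vector{Any}, so β and γ are not concretely typed inside Φ!. It may be better to keep the parameters as a flat vector and count the number of age groups inside Φ!: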

N = size(u,1)
β = @view p[1:N]
γ = p[N+1]
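Put together, a version that infers the number of age groups from the state and keeps p as a flat Float64 vector might look like the sketch below. The name Φ_age! and the 20-group parameter values are hypothetical, chosen only to show that nothing in the function depends on the group count. The first lines illustrate a likely cause of the slowdown: a vector literal mixing a Vector and a scalar has element type Any.

```julia
# Nested parameters: mixing a Vector and a scalar gives a Vector{Any},
# so β and γ are not concretely typed inside the right-hand side.
p_nested = [[0.14, 0.12, 0.2], 0.2]
println(eltype(p_nested))

# Right-hand side for any number of age groups N (rows of u).
# p is laid out flat as [β_1, ..., β_N, γ].
function Φ_age!(du, u, p, t)
    N  = size(u, 1)
    S  = @view u[:, 1]
    I  = @view u[:, 2]
    dS = @view du[:, 1]
    dI = @view du[:, 2]
    dR = @view du[:, 3]
    β  = @view p[1:N]
    γ  = p[N + 1]
    @. dS = -β * S * I
    @. dI = β * S * I - γ * I
    @. dR = γ * I
    return nothing
end

# Example with 20 age groups (hypothetical values).
N  = 20
β  = collect(range(0.1, 0.3; length = N))
p  = vcat(β, 0.2)                                 # Vector{Float64}
u0 = hcat(fill(0.9, N), fill(0.1, N), zeros(N))   # N × 3: S, I, R
du = similar(u0)
Φ_age!(du, u0, p, 0.0)
```

The same u0 and p can then be used in ODEProblem(Φ_age!, u0, 𝒯, p). Another common choice is to pass p as a tuple such as (β, γ) with β a Vector{Float64}; unlike a Vector{Any}, a tuple's element types are known to the compiler, so β, γ = p stays type-stable.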