Ah sorry about ω, that’s my mistake, I forgot to mention this: I changed the parametrization between the posts. It used to be the rate of leaving the R compartment (i.e. 1/duration of immunity in days), but I changed it for better interpretability. Now ω is the duration of immunity (days) itself, so in the model equations the waning rate is calculated as 1/ω. If you want to compare the parameters, it’s 0.00273/day vs. 0.001165/day. Sorry about that. It is correctly implemented in the model. I have attached the full code, including the fitting, for reproducibility. The time unit of the model is days. The step size of the integration is chosen by the solver; I don’t think I can influence it other than by adjusting the tolerances?
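Since only the reciprocal changed, converting between the two parametrizations is a one-liner. A small sketch (the helper names `waning_rate` and `immunity_duration` are hypothetical, not part of the posted code):

```julia
# Waning immunity in the two parametrizations:
# old: ω was a rate (per day); new: ω is a duration (days) and the model uses 1/ω.
waning_rate(duration_days) = 1 / duration_days       # days -> per day
immunity_duration(rate_per_day) = 1 / rate_per_day   # per day -> days

# The two rates quoted above, expressed as mean durations of immunity (roughly 366 and 858 days)
for r in (0.00273, 0.001165)
    println(r, " per day  <->  ", round(immunity_duration(r), digits = 1), " days")
end
```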
I expect the periodic orbit of my model to be stable, because the data for the disease I am working on have shown a very consistent and stable seasonality for decades. The chains produced by the parameter sampling algorithm should converge towards a parameter set that reproduces this pattern. So I have an a priori expectation of a stable solution.
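For a given parameter set, one way to check whether the periodic orbit is actually stable is to compute its Floquet multipliers: the eigenvalues of the Jacobian of the one-year (period) map at a point on the orbit. If all of them have modulus below 1, the orbit is locally asymptotically stable. The sketch below is a self-contained illustration under several assumptions: it uses a hand-written fixed-step RK4 integrator instead of DifferentialEquations.jl, works on the reduced (S, I) system with `R = N - S - I`, reaches the orbit with a 50-year burn-in rather than nlsolve, and approximates the Jacobian by finite differences. Step size, burn-in length and perturbation size are illustrative choices; the parameter values are those of the posted code.

```julia
using LinearAlgebra

# Reduced seasonally forced SIRS model in x = [S, I]; R = N - S - I because N is conserved.
function rhs(x, p, t)
    S, I = x
    R = p.N - S - I
    βeff = p.β * (1 + p.η * sin(2π * (t - p.φ) / 365))
    λ = βeff * I / p.N
    dS = p.μ * p.N - λ * S - p.μ * S + R / p.ω   # ω = duration of immunity (days)
    dI = λ * S - p.σ * I - p.μ * I
    return [dS, dI]
end

# Fixed-step classical RK4 from t0 to t0 + T (T should be a multiple of h).
function flow(x, p, t0, T; h = 0.25)
    t = t0
    for _ in 1:round(Int, T / h)
        k1 = rhs(x, p, t)
        k2 = rhs(x .+ h / 2 .* k1, p, t + h / 2)
        k3 = rhs(x .+ h / 2 .* k2, p, t + h / 2)
        k4 = rhs(x .+ h .* k3, p, t + h)
        x = x .+ h / 6 .* (k1 .+ 2k2 .+ 2k3 .+ k4)
        t += h
    end
    return x
end

# Period map over one forcing period starting at t = 0 (the forcing is 365-periodic).
period_map(x, p) = flow(x, p, 0.0, 365.0)

# Finite-difference Jacobian of the period map; its eigenvalues are the Floquet multipliers.
function floquet_multipliers(xstar, p; δ = 1e-4)
    P0 = period_map(xstar, p)
    J = zeros(2, 2)
    for j in 1:2
        e = zeros(2); e[j] = δ * max(abs(xstar[j]), 1.0)
        J[:, j] = (period_map(xstar .+ e, p) .- P0) ./ e[j]
    end
    return eigvals(J)
end

p = (β = 0.28, η = 0.07, φ = 20.0, ω = 365.0, μ = 1 / (80 * 365), σ = 1 / 5, N = 10_000.0)
x = flow([9999.0, 1.0], p, 0.0, 50 * 365.0)   # burn-in ends at a multiple of 365, i.e. same phase as t = 0
mults = floquet_multipliers(x, p)
println("Floquet multipliers: ", mults, "  moduli: ", abs.(mults))
println("locally stable: ", all(abs.(mults) .< 1))
```

Inside a fitting loop, the same check could in principle be applied to the orbit returned by nlsolve, at the cost of d + 1 extra period integrations per parameter set.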
I like the fact that the nonlinear-solver approach gives more accurate initial conditions, but its run time is a considerable issue when fitting models. My real model takes about 10 hours to run with solution 2 (multithreaded). I could probably optimize the code to make it more efficient, but that’s where my Julia skills end.
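Part of the cost of the nonlinear-solver route is structural: each Newton-type step on the period map P needs P itself plus a Jacobian, and with finite differences that means d + 1 integrations of the map for d unknowns (d = 2 here: S and I). A minimal plain-Newton sketch that counts those calls; `shoot` and the toy affine map are hypothetical stand-ins, and NLsolve's default trust-region method is more robust than this, so its call counts will differ.

```julia
using LinearAlgebra

# Plain Newton iteration for a fixed point x* = P(x*), i.e. a zero of F(x) = P(x) - x.
# Returns the fixed point and the number of calls to P (each call = one integration of the map).
function shoot(P, x0; δ = 1e-6, tol = 1e-10, maxit = 20)
    x = float.(x0)
    d = length(x)
    ncalls = 0
    for _ in 1:maxit
        F = P(x) .- x
        ncalls += 1
        norm(F) < tol && return (x, ncalls)
        J = zeros(d, d)                      # finite-difference Jacobian of F
        for j in 1:d
            e = zeros(d); e[j] = δ * max(abs(x[j]), 1.0)
            J[:, j] = ((P(x .+ e) .- (x .+ e)) .- F) ./ e[j]
            ncalls += 1
        end
        x = x .- J \ F                       # Newton update
    end
    return (x, ncalls)
end

# Toy affine contraction standing in for the period map of a 2-state model.
A = [0.3 -0.2; 0.1 0.4]
b = [5.0, 1.0]
P(x) = A * x .+ b

xstar, ncalls = shoot(P, [0.0, 0.0])
println("fixed point: ", xstar, " after ", ncalls, " map evaluations")
println("exact:       ", (I - A) \ b)
```

Replacing the toy `P` with a function that integrates the model over a fixed time span (as `update` does) gives the same structure; whether this beats a fixed burn-in depends on how many iterations the solver needs and on how long each map integration is.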
For fitting, I am using the DifferentialEvolutionMCMC.jl library (an MCMC sampler based on differential evolution, a genetic-type algorithm), but the Turing.jl library should (hopefully) give similar results. The set-up differs slightly with regard to the likelihood: I prefer to write it out explicitly, but that’s a matter of taste. The output (a `Chains` object from MCMCChains.jl) holds the parameter draws (the posterior distribution) for every iteration you run (minus burn-in). You can then randomly sample from these draws, forward simulate the model with each parameter set to obtain different model trajectories, and judge the fit.
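The posterior predictive check described above (draw parameter sets at random, forward simulate each one, summarise the trajectories) can be sketched as follows. Assumptions: `draws` is a matrix with one row per posterior draw and one column per parameter (for an MCMCChains `Chains` object, `Array(chains)` should give this layout; drop the burn-in rows first if they were kept), and `simulate` is a hypothetical stand-in for re-solving the ODE and calling `get_incidence`. The toy model and draws exist only to make the sketch runnable.

```julia
using Random, Statistics

# Hypothetical stand-in for "solve the ODE with θ and return weekly incidence".
simulate(θ; weeks = 52) = [θ[1] * (1 + θ[2] * sin(2π * w / 52)) for w in 1:weeks]

# Toy posterior draws: rows = draws, columns = parameters (here: level, amplitude).
rng = MersenneTwister(1)
draws = hcat(100 .+ 5 .* randn(rng, 4000), 0.3 .+ 0.02 .* randn(rng, 4000))

# Pick 200 draws at random (with replacement) and forward simulate each one.
idx = rand(rng, 1:size(draws, 1), 200)
trajectories = reduce(hcat, [simulate(draws[i, :]) for i in idx])   # weeks × 200

# Pointwise 2.5%, 50% and 97.5% quantiles across the simulated trajectories.
probs = [0.025, 0.5, 0.975]
bands = [quantile(trajectories[w, :], q) for w in 1:size(trajectories, 1), q in probs]  # weeks × 3
```

The columns of `bands` can then be plotted as a median line with a shaded 95% band on top of the observed counts.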
Yes, disease modelling has gotten some fame (and infamy) in the past year… The discussion we are having is extremely interesting and useful for me. I am a computational epidemiologist, not a mathematician, so I love learning from people with a sounder knowledge of the math underlying it all.
```julia
using StatsPlots, NLsolve, LinearAlgebra, DifferentialEquations, DifferentialEvolutionMCMC, StatsBase, Distributions, Random, LabelledArrays

# SIR ODE model with seasonally forced contact rate
function SIR!(du, u, p, t)
    # states
    S, I, R = u # Susceptibles, Infected, Recovered
    N = S + I + R
    # params
    β = p.β # average rate of infection
    η = p.η # amplitude around average rate of infection
    φ = p.φ # phase shift
    ω = p.ω # duration of immunity (days); the waning rate is 1/ω
    μ = p.μ # mortality/reproduction rate (assuming rectangular demography with 80 years life span)
    σ = p.σ # rate of leaving I compartment (1/duration of infectiousness)
    # FOI
    βeff = β * (1.0 + η * sin(2.0 * π * (t - φ) / 365.0)) # periodically forced infection rate
    λ = βeff * I / N # force of infection FOI
    # change in states
    du[1] = μ * N - λ * S - μ * S + (1 / ω) * R # S
    du[2] = λ * S - σ * I - μ * I # I
    du[3] = σ * I - μ * R - (1 / ω) * R # R
    du[4] = σ * I # C, cumulative incidence (counted at recovery, I -> R)
end

# Initiate ODE problem
solvsettings = (abstol = 1.0e-8,
                reltol = 1.0e-8,
                maxiters = 100_000,
                saveat = 7.0, # corresponds to weekly data
                solver = AutoVern7(Rodas5())) # others: AutoTsit5(Rosenbrock23()) #Tsit5()
theta_fix = (μ = 1.0 / (80 * 365), σ = 1.0 / 5.0)
# same order as keys(priors) below, because the test calls splat these values positionally
theta_est = (β = 0.28, η = 0.07, ω = 365, φ = 20)
p = merge(theta_est, theta_fix)
u0 = @LArray [9999.0, 1.0, 0.0, 1.0] (:S, :I, :R, :C)
N = sum(u0[1:3])
tmax = 365.0 * 3
problem = ODEProblem(SIR!, u0, [0.0, tmax], p)
inits = solve(problem,
              solvsettings.solver,
              abstol = solvsettings.abstol,
              reltol = solvsettings.reltol,
              maxiters = solvsettings.maxiters,
              isoutofdomain = (u, p, t) -> any(z -> z < 0.0, u),
              saveat = solvsettings.saveat)
plot(inits(0:tmax, idxs = 2), inits(0:tmax, idxs = 1))
plot(inits)
```
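For reference, the system implemented in `SIR!` reads as follows, with ω the duration of immunity in days (so the waning rate is 1/ω) and C accumulating the I → R transitions:

$$
\begin{aligned}
\beta_{\text{eff}}(t) &= \beta\left(1 + \eta \sin\frac{2\pi (t-\varphi)}{365}\right), \qquad \lambda(t) = \beta_{\text{eff}}(t)\,\frac{I}{N},\\
\frac{dS}{dt} &= \mu N - \lambda S - \mu S + \frac{R}{\omega},\\
\frac{dI}{dt} &= \lambda S - \sigma I - \mu I,\\
\frac{dR}{dt} &= \sigma I - \mu R - \frac{R}{\omega},\\
\frac{dC}{dt} &= \sigma I .
\end{aligned}
$$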
```julia
# functions to translate between u = [S, I, R, C] and x = [S, I]
function u2x(u)
    return u[1:2]
end
u2x(u0) # test

function x2u(x, N)
    return [x[1:2]; N - sum(x); x[2]] # R = N - S - I; C is (re)initialised to I, as in u0
end
x2u(u0[1:2], N) # test

# Wrapper for ODE: integrate over the problem's tspan starting from x and return the final [S, I]
function update(problem, x, p, N, solvsettings)
    problem_new = remake(problem; u0 = x2u(x, N), p = p)
    sol_new = solve(problem_new,
                    solvsettings.solver,
                    abstol = solvsettings.abstol,
                    reltol = solvsettings.reltol,
                    maxiters = solvsettings.maxiters,
                    isoutofdomain = (u, p, t) -> any(z -> z < 0.0, u),
                    saveat = solvsettings.saveat)
    return u2x(sol_new.u[end])
end
update(problem, u2x(inits.u[end]), p, N, solvsettings) # test

# solve for periodic solution
solution = nlsolve(x -> update(problem, x, p, N, solvsettings) - x, u2x(inits.u[end]))
# check difference
xstar = solution.zero
@show norm(xstar - update(problem, xstar, p, N, solvsettings))
# convert from x = SI to u = SIR
ustar = x2u(xstar, N)
@show sum(ustar[1:3])
# Plug new inits in function and plot
final = solve(remake(problem; u0 = ustar, tspan = [0.0, 3650.0]), solvsettings.solver)
plot(final(0:3650, idxs = 2), final(0:3650, idxs = 1)) # yay, looks great
plot(final)
Ncheck = final(0:3650, idxs = 1) .+ final(0:3650, idxs = 2) .+ final(0:3650, idxs = 3)
minimum(Ncheck) # some numerical imprecision, but not dramatic
```
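`x2u` rebuilds R as `N - S - I`, which relies on the model conserving N: adding the S, I and R equations, the infection, recovery and waning terms cancel and what remains is μN − μ(S + I + R) = 0. Any drift seen in `Ncheck` therefore comes from integration error, not from the model. A quick numerical check of that identity, using a standalone copy of `SIR!` with a plain `Vector` state and a `NamedTuple` of parameters (the helper `max_drift` is hypothetical):

```julia
# Copy of the model right-hand side, with a plain Vector for u and a NamedTuple for p.
function SIR!(du, u, p, t)
    S, I, R = u
    N = S + I + R
    βeff = p.β * (1.0 + p.η * sin(2.0 * π * (t - p.φ) / 365.0))
    λ = βeff * I / N
    du[1] = p.μ * N - λ * S - p.μ * S + (1 / p.ω) * R
    du[2] = λ * S - p.σ * I - p.μ * I
    du[3] = p.σ * I - p.μ * R - (1 / p.ω) * R
    du[4] = p.σ * I
    return nothing
end

# Largest |dS/dt + dI/dt + dR/dt| over a few states and times; should be roundoff-sized.
function max_drift(p, states, times)
    du = zeros(4)
    drift = 0.0
    for t in times, u in states
        SIR!(du, u, p, t)
        drift = max(drift, abs(du[1] + du[2] + du[3]))
    end
    return drift
end

p = (β = 0.28, η = 0.07, φ = 20.0, ω = 365.0, μ = 1.0 / (80 * 365), σ = 1.0 / 5.0)
states = ([9999.0, 1.0, 0.0, 1.0], [7000.0, 40.0, 2960.0, 0.0])
println("max |dN/dt|: ", max_drift(p, states, 0.0:30.0:365.0))
```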
```julia
# Fitted -------------------------------------------------------------------------------------
# Calculate incident cases from cumulative incidence
function get_incidence(sol)
    incidence = [sol.u[t].C - sol.u[t-1].C for t = 2:length(sol.u)]
    incidence = max.(incidence, 0.0) # exclude small negative incidences due to numerical imprecision in C
    return incidence
end

# Fake some data
data = solve(remake(problem; u0 = u0, tspan = [0.0, 25 * 365.0], p = p),
             solvsettings.solver,
             abstol = solvsettings.abstol,
             reltol = solvsettings.reltol,
             maxiters = solvsettings.maxiters,
             isoutofdomain = (u, p, t) -> any(z -> z < 0.0, u),
             saveat = solvsettings.saveat)
plot(data)
data = get_incidence(data)
data = data[length(data)-5*52:length(data)] # take the last ~5 years (261 weeks) of data for fitting, discard rest
plot(data)
data = round.(data .* rand(Uniform(0.5, 1.5), size(data))) # multiplicative noise, rounded to counts
scatter!(data, legend = false)
```
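`get_incidence` takes first differences of the cumulative counter C at the saved weekly points and clips small negative values to zero. The same operation can be written with `diff` and `max.`; a minimal sketch on a plain vector (the name `weekly_incidence` and the toy numbers are illustrative):

```julia
# Weekly incidence from a cumulative counter sampled every 7 days.
# Tiny negative differences caused by solver error are clipped to zero.
weekly_incidence(C) = max.(diff(C), 0.0)

C = [1.0, 11.0, 31.0, 30.999999999, 60.0]   # cumulative counts; the 4th value has a small numerical dip
inc = weekly_incidence(C)
```

Applied to a solution, `weekly_incidence([u.C for u in sol.u])` should give the same vector as `get_incidence(sol)`.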
```julia
# Priors
priors = (
    β = (Uniform(0.0, 1.0),),
    η = (Uniform(0.0, 1.0),),
    ω = (Uniform(1.0, 5.0 * 365.0),),
    φ = (Uniform(0.0, 364.0),)
)
# Lower and upper bounds (in this case equivalent to the arguments of the Uniform)
bounds = (
    (0.0, 1.0),
    (0.0, 1.0),
    (1.0, 5.0 * 365.0),
    (0.0, 364.0)
)
parnames = keys(priors)
```
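Since the bounds here duplicate the arguments of the `Uniform` priors, they could be derived from `priors` so the two cannot drift apart. A sketch, assuming every prior is a univariate distribution with bounded support (`minimum`/`maximum` from Distributions.jl return the endpoints of the support):

```julia
using Distributions

priors = (
    β = (Uniform(0.0, 1.0),),
    η = (Uniform(0.0, 1.0),),
    ω = (Uniform(1.0, 5.0 * 365.0),),
    φ = (Uniform(0.0, 364.0),),
)

# One (lower, upper) tuple per parameter, in the same order as keys(priors).
bounds = map(d -> (minimum(d[1]), maximum(d[1])), values(priors))
```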
```julia
# Solution 1) Inits are calculated in a separate model run-----------------------------------------
tspan = [0.0, 7.0 * (length(data) - 1)] # time of the last weekly data point
# There is a problem with unstable solutions for the nlsolve:
# "Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable"
function loglik(data, problem, theta_fix, parnames, tspan, solvsettings, N, θ...)
    theta_est = NamedTuple{parnames}(θ)
    p_new = merge(theta_fix, theta_est)
    # 1. Solve init model until stable periodic solution is reached
    problem_new = remake(problem, p = p_new, u0 = u0, tspan = [0.0, 3 * 365.0])
    inits1 = solve(problem_new,
                   solvsettings.solver,
                   abstol = solvsettings.abstol,
                   reltol = solvsettings.reltol,
                   maxiters = solvsettings.maxiters,
                   isoutofdomain = (u, p, t) -> any(x -> x < 0.0, u),
                   save_everystep = false)
    # search for the periodic orbit with the proposed parameters p_new (not the global p)
    nlsolved = nlsolve(x -> update(problem_new, x, p_new, N, solvsettings) - x, u2x(inits1.u[end]))
    # extract the zeros from the solution and convert to labelled array
    inits2 = @LArray x2u(nlsolved.zero, N) symbols(problem_new.u0)
    if minimum(inits2) < 0.0 # discard parameter sets where nlsolve returned a non-physical (negative) state
        return -Inf
    end
    # 2. Solve inference model with new inits
    sol = solve(remake(problem; u0 = inits2, tspan = tspan, p = p_new),
                solvsettings.solver,
                abstol = solvsettings.abstol,
                reltol = solvsettings.reltol,
                maxiters = solvsettings.maxiters,
                isoutofdomain = (u, p, t) -> any(z -> z < 0.0, u),
                saveat = solvsettings.saveat)
    # 3. calculate incidence and loglik
    incidence = vcat(sol.u[1].C, get_incidence(sol))
    ll = 0.0
    for i in 1:length(incidence)
        ll += logpdf(Poisson(incidence[i]), data[i])
    end
    #ll = sum(logpdf.(Poisson.(incidence), data))
    return ll
end
loglik(data, problem, theta_fix, parnames, tspan, solvsettings, N, theta_est...) # test

# Run DE MCMC
burnin = 2000
n_iter = 5000 # not enough for convergence, for benchmarking purposes only
n_groups = 3 # default is 4
Np = length(priors) * 3 # number of particles per group. Default is nparam*3
model = DEModel(problem, theta_fix, parnames, tspan, solvsettings, N; priors, model = loglik, data = data)
de = DE(; n_groups, Np, bounds, burnin = burnin, priors)
@time chains1 = sample(model, de, MCMCThreads(), n_iter, progress = true, discard_burnin = false)
```
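The accumulation loop in `loglik` is equivalent to the vectorized one-liner left commented out. A quick self-contained check on made-up numbers (the vectors `λ` and `y` and the helper names are illustrative only):

```julia
using Distributions

λ = [12.3, 40.1, 7.8, 0.5]     # model incidence (Poisson means)
y = [10.0, 43.0, 9.0, 0.0]     # observed counts, stored as Float64 like the fake data

# Loop version, as in loglik
function loglik_loop(λ, y)
    ll = 0.0
    for i in eachindex(λ, y)
        ll += logpdf(Poisson(λ[i]), y[i])
    end
    return ll
end

# Vectorized version, as in the commented-out line
loglik_vec(λ, y) = sum(logpdf.(Poisson.(λ), y))

println(loglik_loop(λ, y), "  ", loglik_vec(λ, y))
```

Note that a week where the simulated incidence is exactly 0 but the observed count is positive contributes `-Inf`, so such parameter sets are rejected outright.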
```julia
# Solution 2) Inits are simulated in the same model run-----------------------------------------
# ~10-year burn-in; the start (-7*521 = -3647) is a multiple of 7 so that saveat = 7 lands on t = 0, 7, 14, ...
tspan2 = [-7.0 * 521, 7.0 * (length(data) - 1)]
function loglik2(data, problem, theta_fix, parnames, solvsettings, θ...)
    theta_est = NamedTuple{parnames}(θ)
    p_new = merge(theta_fix, theta_est)
    problem_new = remake(problem, p = p_new)
    sol = solve(problem_new,
                solvsettings.solver,
                abstol = solvsettings.abstol,
                reltol = solvsettings.reltol,
                maxiters = solvsettings.maxiters,
                isoutofdomain = (u, p, t) -> any(z -> z < 0.0, u),
                saveat = solvsettings.saveat)
    incidence = get_incidence(sol)
    incidence = incidence[(length(incidence)-length(data)+1):length(incidence)] # keep the last length(data) weeks
    ll = 0.0
    for i in 1:length(incidence)
        ll += logpdf(Poisson(incidence[i]), data[i])
    end
    return ll
end
# test
problem2 = ODEProblem(SIR!, u0, tspan2, p)
loglik2(data, problem2, theta_fix, parnames, solvsettings, theta_est...) # test
# Run DE MCMC
model2 = DEModel(problem2, theta_fix, parnames, solvsettings; priors, model = loglik2, data = data)
@time chains2 = sample(model2, de, MCMCThreads(), n_iter, progress = true, discard_burnin = false)
```
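With a scalar `saveat`, DifferentialEquations.jl saves at `tspan[1]:saveat:tspan[2]` (and, by default, also at the end point), so whether the weekly outputs line up with t = 0, 7, 14, … depends on the start time of the burn-in. A dependency-free check of two candidate grids using plain ranges (no solver involved); `tf = 1820 = 7 × 260` is what `7 * (length(data) - 1)` gives for the 261 weekly points kept in the fake data:

```julia
tf = 7.0 * 260                     # end of the fitting window
grid_a = -10 * 365.0:7.0:tf        # start at -3650: not a multiple of 7
grid_b = -7.0 * 521:7.0:tf         # start at -3647 = -7*521 (≈ 10 years)

# Does the grid contain t = 0, and where does it stop before the end point is appended?
println("grid_a: contains 0? ", 0.0 in grid_a, ", last regular point = ", last(grid_a))
println("grid_b: contains 0? ", 0.0 in grid_b, ", last regular point = ", last(grid_b))
```

With `grid_a` the regular points stop at 1817, so the appended end point at 1820 makes the last "week" only 3 days long and every other week is shifted by 3 days relative to t = 0; with `grid_b` the grid ends exactly at `tf`.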