# Early termination of ODE integration with callback

Hi all

I have a simple epidemiological ODE model with three states: S(usceptibles), I(nfected) and R(ecovered). Transmission is modelled with a periodic, seasonal function. Depending on the initial conditions and parameter values, the I state converges towards stable yearly peaks of either equal or alternating size (see pics).

I would like the integration to stop when this “steady state” is reached, i.e. when the peaks have a consistent size. It isn’t a steady state in the mathematical sense, since the equations depend on time (through the seasonal forcing). I thought about doing this with a continuous callback function to capture the time point when this happens for the first time. The states at this moment would then be used as initial values for another model. However, I have some difficulties with the callback function. The main idea is to monitor the second state (I) and to terminate when its value is equal to its value 2 years (730 days) earlier. I realize this is not the best condition to choose, but I couldn’t think of anything better for now.

Below is my attempt, which throws an error. I appreciate any help. Bw, Fabienne

```julia
using DifferentialEquations, StatsBase, StatsPlots, Random, LabelledArrays, DiffEqCallbacks

Random.seed!(42)

# ODE model: simple SIR model with seasonally forced contact rate
function SIR!(du,u,p,t)

    # states
    S, I, R = u
    N = S + I + R

    # params
    β = p.β
    η = p.η
    φ = p.φ
    ω = p.ω
    μ = p.μ
    σ = p.σ

    # FOI
    βeff = β * (1.0+η*sin(2.0*π*(t-φ)/365.0))
    λ = βeff*I/N

    # change in states
    du[1] = μ*N - λ*S - μ*S + ω*R
    du[2] = λ*S - σ*I - μ*I
    du[3] = σ*I - μ*R - ω*R

end

# callback function
function condition_terminate(u,t,integrator)
    floor(abs(u[2][t]-u[2][t-730]))
end
affect!(integrator) = terminate!(integrator)
cb_terminate = ContinuousCallback(condition_terminate, affect!)

# Solver settings
tspan = (0.0, 7300.0)
abstol = 1.0e-8
reltol = 1.0e-8
maxiters = 1e7
saveat = 1.0
solver = AutoVern7(Rodas5())

# Initiate ODE problem
p = @LArray [0.28,0.07,20,1.0/365,1.0/(80*365),1.0/5.0] (:β,:η,:φ,:ω,:μ,:σ)
u0 = @LArray [9999.0,1.0,0.0] (:S,:I,:R)

# ODE problem and solution
problem = ODEProblem(SIR!,u0,tspan,p)
sol = solve(problem, solver,
            abstol=abstol, reltol=reltol,
            maxiters=maxiters,
            isoutofdomain=(u,p,t)->any(x->x<0.0,u),
            callback=cb_terminate,
            saveat=saveat)
```


throws the following error:

```
ERROR: LoadError: MethodError: no method matching getindex(::Float64, ::Float64)
Closest candidates are:
  getindex(::Real, ::Num) at C:\Users\micky\.julia\packages\Symbolics\h8kPL\src\register.jl:51
  getindex(::Real, ::SymbolicUtils.Symbolic) at C:\Users\micky\.julia\packages\Symbolics\h8kPL\src\register.jl:51
  getindex(::Number) at number.jl:75
```


The syntax `u[2][t]` is the problem: `u` is the state vector at the current time, so `u[2]` is already a number (the second component), and indexing that `Float64` again with `t` fails. To access solution values at interpolated past times, you need to use the solution object stored in the integrator.

This might work (untested)

```julia
function condition_terminate(u,t,integrator)
    t>730 ? floor(abs(u[2] - integrator.sol(t-730, idxs=2))) : 1.0
end
```


Awesome, the termination works. The only problem now is that I have to refine the termination condition so that it compares the values at the peaks, and not just any values. Right now it stops after 791 days, because the condition is already fulfilled at day 61, on the upswing of the first peak. Not what I intended, but I am one step closer to the solution. Does anyone have a suggestion/idea how to refine this criterion? TIA

The alternative would be to run the model for a defined period and have some (post integration) checks on the solution to compare the peak sizes, and run for longer if conditions are not met (and so on until conditions met). But if there is a more elegant callback solution, I’d prefer it.

Edit: changing the conditional to something like `t>x*730.0 ?`, so that at least 2x years are integrated first, already improves this. I might just go with this (hacky but working) solution. Thanks @contradict for the code, I appreciate it.
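One way to compare the same phase of the cycle every time, rather than arbitrary values, is to sample the state once per forcing period (a stroboscopic, or Poincare, sampling) and stop when a sample repeats the one from two years earlier. A minimal sketch, assuming `PeriodicCallback` from DiffEqCallbacks.jl, the question's model with ω as a rate, and a made-up relative tolerance `rtol`; the helper name `stroboscopic_termination` is hypothetical:

```julia
using OrdinaryDiffEq, DiffEqCallbacks

# SIR model from the question; ω is a rate (1/day), parameters in a NamedTuple
function SIR!(du, u, p, t)
    S, I, R = u
    N = S + I + R
    λ = p.β * (1.0 + p.η * sin(2π * (t - p.φ) / 365.0)) * I / N
    du[1] = p.μ * N - λ * S - p.μ * S + p.ω * R
    du[2] = λ * S - p.σ * I - p.μ * I
    du[3] = p.σ * I - p.μ * R - p.ω * R
    return nothing
end

# Hypothetical helper: sample the full state once per forcing period (always at the
# same phase of the year) and stop once it matches the sample `lag` periods earlier.
# lag = 2 accepts both equal and alternating yearly peaks.
function stroboscopic_termination(; period = 365.0, lag = 2, rtol = 1e-3)
    samples = Vector{Vector{Float64}}()
    function check!(integrator)
        push!(samples, copy(integrator.u))
        n = length(samples)
        # component-wise relative comparison with the sample `lag` periods earlier
        if n > lag && all(isapprox.(samples[n], samples[n - lag]; rtol = rtol))
            terminate!(integrator)
        end
        return nothing
    end
    return PeriodicCallback(check!, period), samples
end

p = (β = 0.28, η = 0.07, φ = 20.0, ω = 1.0 / 365, μ = 1.0 / (80 * 365), σ = 1.0 / 5.0)
prob = ODEProblem(SIR!, [9999.0, 1.0, 0.0], (0.0, 365.0 * 200), p)
cb, samples = stroboscopic_termination()
sol = solve(prob, Tsit5(); callback = cb, abstol = 1e-8, reltol = 1e-8)
println("stopped at t = ", sol.t[end], " after ", length(samples), " yearly samples")
```

Because points on the upswing are only ever compared with points at the same date in another year, the day-61 false positive cannot occur; the final state `sol.u[end]` is then a candidate initial value for the next model.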

I have a few ideas, and I can’t decide if they are more or less hacky than what you have working already.

Starting with one that is probably still hacky: since you always want to stop at a peak, you could try adding the absolute value of the derivative to your condition:

```julia
function condition_terminate(u,t,integrator)
    du = similar(u)
    SIR!(du, u, integrator.p, t)   # evaluate the right-hand side to get dI/dt at the current state
    t>730 ? floor(abs(u[2] - integrator.sol(t-730, idxs=2))) + abs(du[2]) : 1.0
end
```


You could try evaluating at multiple delays, maybe the same way you already are; or, if you wanted to get fancy, you could use a DDEProblem and integrate the autocorrelation function of the solution at a few fixed delays as you go.

Another option is to add a band-pass filter to your integration and stop when it exceeds some threshold.


Ah, that’s smart. I wasn’t sure why it would be necessary to run the ODE function again within the condition function, so I modified the code a bit to access the stored du. I am not sure if this is the correct way, but it does not throw an error (it runs smoothly). However, there is no termination when this model is run for 20 years, because this condition just never happens (du never seems to be caught at exactly 0). So I guess one would have to make the integration steps really small, but that would probably make things so slow that in the end there is no advantage over just running the model for a pre-specified longer time. I might just do the latter for now. Running out of ideas. Thanks anyway.

```julia
function condition_terminate(u,t,integrator)
    t>730.0 ? floor(abs(u[2] - integrator.sol(t-730.0, idxs=2))) + abs(integrator.sol(t, Val{1}, idxs=2)) : 1.0
end
```


I had another idea to compare u AND du for t and t-730 with the following code:

```julia
function condition_terminate(u,t,integrator)
    cond1 = floor(abs(u[2] - integrator.sol(t-730.0, idxs=2)))
    cond2 = floor(abs(integrator.sol(t, Val{1}, idxs=2)-integrator.sol(t-730.0, Val{1}, idxs=2)), digits=2)
    #println(cond1 + cond2)
    t>4*730.0 ? (cond1 + cond2) : 1.0
end
```


The condition (cond1 + cond2) becomes zero at some point (checked by printing to the console), but the integration does not stop. Not sure what the issue is.
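`ContinuousCallback` is designed to locate where the condition crosses zero between steps, so a condition that is non-negative by construction and only touches zero is a poor fit; that may be why reaching zero here did not stop the run. A hedged alternative is a `DiscreteCallback`, whose condition is an ordinary `Bool` evaluated after every accepted step. In the sketch below (hypothetical helper `peak_termination`, assumed tolerance `rtol`, the question's model with ω as a rate) the condition simply returns `true`, so that `affect!` can track the yearly maximum of I and call `terminate!` once the peaks repeat two years apart. `dtmax = 1.0` is an assumption that keeps steps at most a day long, so the step-wise maximum approximates the true peak:

```julia
using OrdinaryDiffEq

# SIR model from the question; ω is a rate (1/day) here
function SIR!(du, u, p, t)
    S, I, R = u
    N = S + I + R
    λ = p.β * (1.0 + p.η * sin(2π * (t - p.φ) / 365.0)) * I / N
    du[1] = p.μ * N - λ * S - p.μ * S + p.ω * R
    du[2] = λ * S - p.σ * I - p.μ * I
    du[3] = p.σ * I - p.μ * R - p.ω * R
    return nothing
end

# Hypothetical helper: track the maximum of I within each year and stop once the
# yearly peaks repeat two years apart (covers equal and alternating peak sizes).
function peak_termination(; period = 365.0, rtol = 1e-3)
    peaks = Float64[]      # max of I in each completed year (the first year includes the transient)
    year = Ref(0)          # index of the year currently being tracked
    runmax = Ref(-Inf)     # running maximum of I within that year
    condition(u, t, integrator) = true   # Bool: run affect! after every accepted step
    function affect!(integrator)
        y = floor(Int, integrator.t / period)
        if y > year[]                    # a year boundary has been crossed
            push!(peaks, runmax[])
            year[] = y
            runmax[] = -Inf
            n = length(peaks)
            if n >= 4 && isapprox(peaks[n], peaks[n - 2]; rtol = rtol) &&
                         isapprox(peaks[n - 1], peaks[n - 3]; rtol = rtol)
                terminate!(integrator)
            end
        end
        runmax[] = max(runmax[], integrator.u[2])
        u_modified!(integrator, false)   # the state is only observed, never changed
        return nothing
    end
    return DiscreteCallback(condition, affect!; save_positions = (false, false)), peaks
end

p = (β = 0.28, η = 0.07, φ = 20.0, ω = 1.0 / 365, μ = 1.0 / (80 * 365), σ = 1.0 / 5.0)
prob = ODEProblem(SIR!, [9999.0, 1.0, 0.0], (0.0, 365.0 * 200), p)
cb, peaks = peak_termination()
sol = solve(prob, Tsit5(); callback = cb, dtmax = 1.0, abstol = 1e-8, reltol = 1e-8)
println("stopped at t = ", sol.t[end], " after ", length(peaks), " complete years")
```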

I feel like the right conceptual framework and solution method for this problem is the following. You have a 3d nonautonomous differential equation with a periodic forcing function, say $\dot{x} = f(x,t)$ where $f(x,t+T) = f(x,t)$ for some fixed $T$. Solutions of this system tend to converge towards a stable periodic orbit $x_{po}(t)$ which also has period $T$, i.e. $x_{po}(t+T) = x_{po}(t)$.

The way to find the stable periodic orbit $x_{po}(t)$ is to define a Poincare map $\phi : x(t_0) \rightarrow x(t_0+T)$, which maps any state at time $t_0$ to its time-$T$ forward integration (using a numerical time integrator). The fixed constant $t_0$ is arbitrary and might as well be $t_0=0$. The map $\phi$ is now just a 3d map; continuous time has been factored out. You find the periodic orbit by finding a fixed point of $\phi$, that is, a solution $x^*$ of the nonlinear system $x = \phi(x)$. You can do this either by iterating $\phi$ or by finding a root of $x - \phi(x) = 0$ with a nonlinear solver.
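Written out (a standard sketch; $D\phi$ denotes the Jacobian of the map and $\mathrm{Id}$ the identity), the time-$T$ map, its fixed-point condition, and one Newton step for the root of $F(x) = x - \phi(x)$ are:

$$
\begin{aligned}
\phi(x_0) &= x(t_0 + T), \quad \text{where } \dot{x} = f(x,t),\ x(t_0) = x_0,\\
x^* &= \phi(x^*) \iff F(x^*) = 0, \qquad F(x) = x - \phi(x),\\
x_{k+1} &= x_k - \bigl(\mathrm{Id} - D\phi(x_k)\bigr)^{-1}\bigl(x_k - \phi(x_k)\bigr).
\end{aligned}
$$

Plain iteration $x_{k+1} = \phi(x_k)$ converges only when all eigenvalues of $D\phi(x^*)$ lie inside the unit circle, at a rate set by the largest of them; the Newton step does not need that, which is why the root-finding route can also find unstable orbits.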

Finding the root with a nonlinear solver gives you the exact fixed point $x^*$ (up to the accuracy of your time integrator), which you can then time-integrate to get the full periodic orbit $x_{po}(t)$. Once you have the full periodic orbit, you can analyze it any way you want, including finding the peaks by maximization.

This is all pretty easy to do in Julia with a few function definitions and a couple of packages:

```julia
using NLsolve

function f(x, p, t)
    # right-hand side of the T-periodic ODE dx/dt = f(x, p, t)
end

function phi(x)
    # define a time-T map starting at t0=0 using numerical time integration
end

xguess = [1.0; 0.0; 0.0] # or something more sensible (floats, so that .= below can store the result)

# iterate a few times to get close to the stable periodic orbit
for n=1:5
    xguess .= phi(xguess)
end

# find the fixed point x = phi(x) as a root of x - phi(x)
nlsolve(x -> x - phi(x), xguess)
```



I’ve done this for autonomous dynamical systems such as Lorenz and Rossler to find unstable periodic orbits. If it helps I can post that code.


thanks @John_Gibson, that’s an interesting approach. Yours is the mathematically most correct solution, mine is the duct tape solution. Yes, it would be very helpful to see a code example, thank you.


You’re welcome. I’ll post code that solves your problem, since you’ve given the differential equation and an initial condition. I think the Poincare section code is going to end up simpler than the termination call-back.


Here goes. Most of the code below is my adaptation of your code to use a simple Runge-Kutta integrator, because I haven’t learned DifferentialEquations.jl yet. The part that defines the Poincare map and solves for its fixed point starts at the comment `# compute periodic orbit as fixed point of Poincare map`.

```julia
using Plots, LinearAlgebra, NLsolve

# SIR ODE model with seasonally forced contact rate
# dx/dt = f(x, p, t), x=state, p=params, t=time
function f(x, p, t)

    # states
    S, I, R = x[1], x[2], x[3]
    N = S + I + R

    # params
    β = p[1]
    η = p[2]
    φ = p[3]
    ω = p[4]
    μ = p[5]
    σ = p[6]

    # FOI
    βeff = β * (1.0+η*sin(2.0*π*(t-φ)/365.0))
    λ = βeff*I/N

    # change in states
    dxdt = zeros(3)
    dxdt[1] = μ*N - λ*S - μ*S + ω*R
    dxdt[2] = λ*S - σ*I - μ*I
    dxdt[3] = σ*I - μ*R - ω*R

    dxdt
end

# classical 4th-order Runge-Kutta: N time points (N-1 steps of size Δt) starting at t₀;
# returns the time grid t and an N × length(x₀) matrix x of states
function rk4(f, x₀, p, Δt, N, t₀=0.0)
    Δt2 = Δt/2
    Δt6 = Δt/6
    t₁ = t₀ + (N-1)*Δt

    t = range(t₀, t₁, length=N)
    x = zeros(N, length(x₀))
    x[1,:] .= x₀

    for n = 1:N-1
        xn = x[n,:]
        tn = t[n]
        s1 = f(xn, p, tn)
        s2 = f(xn + Δt2*s1, p, tn + Δt2)
        s3 = f(xn + Δt2*s2, p, tn + Δt2)
        s4 = f(xn + Δt *s3, p, tn + Δt)
        x[n+1, :] .= xn + Δt6*(s1 + 2s2 + 2s3 + s4)
    end
    t,x
end

# Initiate ODE problem
p = [0.28; 0.07; 20; 1.0/365;  1.0/(80*365); 1.0/5.0]
x0 = [9999.0; 1.0; 0.0] # (:S,:I,:R)

############################################
# compute periodic orbit as fixed point of Poincare map

# define time t=365 map for SIR flow
function ϕ(x,p)
    t0 = 0.0
    t1 = 365.0
    N = 365         # rk4 returns N points spanning (N-1)*Δt, so this grid ends at t = 364;
    Δt = (t1-t0)/N  # N = 366 with Δt = (t1-t0)/(N-1) would end exactly at t1 = 365

    t,x = rk4(f, x, p, Δt, N,t0)

    x[N,:]
end

# iterate a few times to get close to periodic orbit
xguess = copy(x0)

@show xguess
for n = 0:10
    xguess .= ϕ(xguess, p)
    @show xguess
end

# solve for fixed point of Poincare map
using NLsolve
solution = nlsolve(x -> ϕ(x,p)-x, xguess)
xstar = solution.zero

@show solution
@show xstar
@show ϕ(xstar, p)
@show norm(xstar - ϕ(xstar,p))

t0 = 0.0
t1 = 3650.0
Δt = 1.0
N = Int(round((t1-t0)/Δt))

# integrate the original initial condition and the periodic orbit
tpo, xpo = rk4(f, xstar, p, Δt, N)
t,x = rk4(f, x0, p, Δt, N)

# plot time series
plot(t, xpo[:,2], color=:blue, width=2, label="periodic orbit")
plot!(t, x[:,2], color=:red, label="trajectory")
plot!(xlabel="t", ylabel="I(t)")
savefig("timeseries.png")

# make I vs S phase-space plot of system
tpo, xpo = rk4(f, xstar, p, Δt, 366)
plot(x[200:end,1], x[200:end,2], color=:red, label="trajectory")
plot!(xpo[:,1], xpo[:,2], color=:blue, width=2, label="periodic orbit")
plot!(xlabel="S(t)", ylabel="I(t)")
savefig("phasespace.png")
```


Output is

```
xguess = [9999.0, 1.0, 0.0]
xguess = [7065.635716770304, 0.014063564946359613, 2934.350219664753]
xguess = [7923.0650401852645, 102.51023702064641, 1974.4247227940905]
xguess = [7637.778725598237, 1.8422707160752916, 2360.379003685687]
xguess = [7159.44365980562, 1.6620939498250187, 2838.894246244554]
xguess = [7208.506374190902, 7.954651275748168, 2783.538974533347]
xguess = [7302.350356940733, 3.378467115607611, 2694.271175943651]
xguess = [7229.414478863247, 3.472326380755114, 2767.1131947559866]
xguess = [7238.969206938598, 4.270331520940551, 2756.7604615404434]
xguess = [7251.683589443475, 3.800140715207387, 2744.5162698412923]
xguess = [7242.4396291871835, 3.840150596655532, 2753.7202202161425]
xguess = [7244.108456566993, 3.929470139271008, 2751.9620732937174]
solution = Results of Nonlinear Solver Algorithm
 * Algorithm: Trust-region with dogleg and autoscaling
 * Starting Point: [7244.108456566993, 3.929470139271008, 2751.9620732937174]
 * Zero: [7165.838129824154, 3.8425663839108384, 2721.5326627260224]
 * Inf-norm of residuals: 0.000000
 * Iterations: 3
 * Convergence: true
   * |x - x'| < 0.0e+00: false
   * |f(x)| < 1.0e-08: true
 * Function Calls (f): 4
 * Jacobian Calls (df/dx): 4
xstar = [7165.838129824154, 3.8425663839108384, 2721.5326627260224]
ϕ(xstar, p) = [7165.838129824161, 3.8425663839108632, 2721.532662726025]
norm(xstar - ϕ(xstar, p)) = 7.77076593286149e-12
```


My code is wasteful in a number of ways, such as the out-of-place `f(x,p,t)`, `rk4` returning the whole `t, x` time series when sometimes we throw everything but the last value away, allocating in the inner loop of `rk4`, etc. Plus it would be better to define the ODE on only two of S, I, R and get the third from the total being constant. But for the given purpose these things hardly matter.
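The claim that the total is constant follows from adding the three equations of the model (with ω as a rate, as in the code above):

$$
\frac{dN}{dt} = \dot S + \dot I + \dot R
= (\mu N - \lambda S - \mu S + \omega R) + (\lambda S - \sigma I - \mu I) + (\sigma I - \mu R - \omega R)
= \mu N - \mu (S + I + R) = 0 .
$$

So $N = S + I + R$ stays at its initial value and $R = N - S - I$ can be eliminated.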


BTW, an advantage of solving for the periodic orbit explicitly is that once you have it, you can easily track how it changes as a function of the system parameters by applying parametric continuation methods; see the BifurcationKit.jl documentation (“Bifurcation Analysis in Julia”).
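A minimal sketch of the simplest such method, natural-parameter continuation, assuming the reduced map on $[S, I]$ (with $R = N - S - I$) and ω as a rate: step the forcing amplitude η and warm-start `nlsolve` at each value from the previous periodic orbit. The sweep range, tolerances, and the helper `continue_in_η` are illustrative assumptions; a package such as BifurcationKit.jl adds step-size control and pseudo-arclength continuation, which can follow a branch around folds where this plain loop would fail.

```julia
using OrdinaryDiffEq, NLsolve

# SIR model with ω as a rate (1/day)
function SIR!(du, u, p, t)
    S, I, R = u
    N = S + I + R
    λ = p.β * (1.0 + p.η * sin(2π * (t - p.φ) / 365.0)) * I / N
    du[1] = p.μ * N - λ * S - p.μ * S + p.ω * R
    du[2] = λ * S - p.σ * I - p.μ * I
    du[3] = p.σ * I - p.μ * R - p.ω * R
    return nothing
end

const Npop = 10_000.0                          # assumed constant population size
x2u(x) = [x[1], x[2], Npop - x[1] - x[2]]       # reduced state [S, I] -> [S, I, R]

# one-year (Poincare) map on the reduced state
function ϕ(x, p)
    sol = solve(ODEProblem(SIR!, x2u(x), (0.0, 365.0), p), Tsit5();
                abstol = 1e-10, reltol = 1e-10, save_everystep = false)
    return sol.u[end][1:2]
end

# hypothetical helper: solve for the fixed point at each η, reusing the last solution as guess
function continue_in_η(ηs, x0, base)
    xguess = copy(x0)
    results = []
    for η in ηs
        p = merge(base, (η = η,))
        res = nlsolve(x -> ϕ(x, p) .- x, xguess)
        push!(results, res)
        xguess = res.zero
    end
    return results
end

base = (β = 0.28, η = 0.07, φ = 20.0, ω = 1.0 / 365, μ = 1.0 / (80 * 365), σ = 1.0 / 5.0)
ηs = 0.05:0.01:0.09

# initial guess for the first η: state after 10 years of plain time integration
warm = solve(ODEProblem(SIR!, [9999.0, 1.0, 0.0], (0.0, 3650.0), merge(base, (η = first(ηs),))), Tsit5())
results = continue_in_η(ηs, warm.u[end][1:2], base)

for (η, res) in zip(ηs, results)
    println("η = $η: S* = $(res.zero[1]), I* = $(res.zero[2]), converged = $(converged(res))")
end
```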


Excellent, thanks so much. Will look into it in detail later this week! Thanks a lot.


I worked through your code and adapted it a bit. I basically replaced your hand-coded RK4 solver with the functions in DifferentialEquations.jl (see “problem” and “solve”) and adjusted the phi function accordingly. I have also changed the solver algorithm to Tsit5(), which is the standard, but RK4() works equally well. The code runs smoothly and produces a stable periodic solution.

However, I ran into a problem. The resulting zeros do not sum up to the original total N (in my case 10000). This is a problem because I have a stable population in the model, which should neither grow nor decline. The discrepancy from 10000 is larger the further the initial guesses are from the stable solution.

I think this is a sophisticated solution if one is interested in the periodic stability itself. For my project, I just need it to obtain sensible inits for another model, so I think I will try to work out a solution that terminates the integration with a callback, where I can be sure that the total N equals that of the initial values.

Thank you anyway for your solution, I learned a lot!
Bw, Fabienne

```julia
using StatsPlots, NLsolve, LinearAlgebra, DifferentialEquations, Statistics

# SIR ODE model with seasonally forced contact rate
# dx/dt = f(x, p, t), x=state, p=params, t=time (out-of-place, despite the "!" in the name)
function SIR!(x, p, t)

    # states
    S, I, R = x[1], x[2], x[3]
    N = S + I + R

    # params
    β = p[1]
    η = p[2]
    φ = p[3]
    ω = p[4]
    μ = p[5]
    σ = p[6]

    # FOI
    βeff = β * (1.0+η*sin(2.0*π*(t-φ)/365.0))
    λ = βeff*I/N

    # change in states
    dxdt = zeros(3)
    dxdt[1] = μ*N - λ*S - μ*S + ω*R
    dxdt[2] = λ*S - σ*I - μ*I
    dxdt[3] = σ*I - μ*R - ω*R

    dxdt
end

# Initiate ODE problem
p = [0.28; 0.07; 20; 1.0/365;  1.0/(80*365); 1.0/5.0]
x0 = [9999.0; 1.0; 0.0] # (:S,:I,:R)
tmax = 365.0*10 # simulate for 10 years to get close to stable periodic solution
probleminit = ODEProblem(SIR!,x0,(0.0,tmax),p)
solinit = solve(probleminit, Tsit5())
plot(solinit(0:tmax,idxs=2), solinit(0:tmax,idxs=1))
plot(solinit)

problem = ODEProblem(SIR!,solinit.u[end],(0.0,365.0),p)
sol = solve(problem, Tsit5())
plot(sol)
plot(sol(0:365,idxs=2), sol(0:365,idxs=1))

# one-year map: integrate from x over the problem's tspan and return the final state
function ϕ2(problem, x, p)
    problem_new = remake(problem, u0=x, p=p)
    sol_new = solve(problem_new,
                    Tsit5())
                    #isoutofdomain=(u,p,t)->any(x->x<0.0,u))
    return sol_new.u[end]
end

# check if works
ϕ2(problem, sol.u[end], p)

# Solve
solution2 = nlsolve(x -> ϕ2(problem, x, p)-x, sol.u[end])
solution2.zero

# check difference
@show norm(solution2.zero - ϕ2(problem, solution2.zero,p))

# Plug new inits in function and plot
final = solve(remake(problem, u0=solution2.zero, tspan=(0.0,3650.0)), Tsit5())
plot(final(0:3650,idxs=2), final(0:3650,idxs=1))
plot(final)

sum(solution2.zero) # --> this is a problem, should be equal to sum(x0)

N = final(0:3650, idxs=1) .+ final(0:3650, idxs=2) .+ final(0:3650, idxs=3)
minimum(N)
mean(N)
```


Ah, good point. I should have known that the nonlinear solver would not respect the constraint sum(x) = constant. I’ve redefined ϕ as a map on x = [S, I], with R determined by the constraint S + I + R = const, and that fixes the problem. Here’s the relevant code.

```julia
Npopulation = sum(x0)   # total population, constant in time

# functions to translate between u = [S I R] and x = [S I]
u2x(u) = u[1:2]
x2u(x, N=Npopulation) = [x[1:2] ; N-sum(x)]

# one-year map on the reduced state x = [S I]
function ϕ(problem, x, p)
    problem_new = remake(problem, u0=x2u(x), p=p)
    sol_new = solve(problem_new, Tsit5())
    return u2x(sol_new.u[end])
end

# check if works
ϕ(problem, u2x(sol.u[end]), p)

# solve for periodic orbit
solution = nlsolve(x -> ϕ(problem, x, p)-x, u2x(sol.u[end]))
xstar = solution.zero

# check difference
@show norm(xstar - ϕ(problem, xstar, p))

# convert from x = SI to u = SIR
ustar = x2u(xstar)
@show sum(ustar) # should now equal sum(x0)

# Plug new inits in function and plot
final = solve(remake(problem, u0=ustar, tspan=(0.0,3650.0)), Tsit5())
plot(final(0:3650,idxs=2), final(0:3650,idxs=1))
plot(final)
```


Output is

```
norm(xstar - ϕ(problem, xstar, p)) = 9.636691624635262e-9
ustar = [7250.055941101794, 3.8322070898320515, 2746.111851808374]
sum(ustar) = 10000.0
```


Thanks for the DifferentialEquations translation. If you ever want to pursue this further, let me know. I’ve worked on these problems in the context of turbulent fluid flows. I do that in C++ but have been using Julia in teaching the methods in lower-dimensional contexts. And I’d be very interested in learning more about epidemiological models and contributing there if I can.


Thanks @John_Gibson this is fantastic, it works like a charm. You also don’t need to have initial guesses close to the stable solution. I can run it for 2-3 years and use the last values as init guesses, and it will still be correct. This saves a lot of computational time.

My real model is a bit more complex (12 compartments). I will now try to scale this up and benchmark to see if your solution is faster than just running the model for ~10 years. I’ll post the results. If I continue with this, I’ll get back to you either on the forum or your work email (unh edu).


I made up some data and fitted this SIR toy model in a Bayesian framework with two versions:

1. your solution that evaluates whether a stable periodic orbit is reached and which yields the inits which are then fed into the model again for the fitting process
2. a solution, where the model just runs for 10 years prior to the data to allow it to reach periodic stability.

The benchmark unfortunately is not in favour of solution 1: solution 1 took 13 times as long to fit compared to solution 2 (all algorithm parameters being identical).

The problem is that some parameter combinations do not lead to stable periodic solutions (e.g. (μ = 3.424657534246575e-5, σ = 0.2, β = 0.00742823720731039, η = 0.7879291468271944, ω = 858.7279867934733, φ = 186.71200800843744)). In these cases the nonlinear solver does not return meaningful zeros, and a warning is printed:
`Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable.`

I can “ignore” these instances by having the loglikelihood function return -Inf, which means the sampling algorithm will try to leave this specific parameters space. However, this seems very inefficient. I would probably obtain the correct solution if I run this long enough, but the time it takes to get there is just too long compared to solution 2.

I have no other idea how to address the issue of unstable solutions within a fitting framework, so I think I will have to go back to solution 2 (or the callback termination), unfortunately. Too bad, this would have been a very clean solution. Maybe it will be useful for another project. Thank you anyway for your time and help, I really appreciate it.

If the periodic orbit solution is strongly stable and you only need it approximately, it’s probably going to be faster to find it by pure time integration, as in 2. If it’s weakly or marginally stable, the nonlinear solver approach will probably be faster. If it’s unstable, only the nonlinear solver will find it.

But the solver method could still be better, even if it’s slower, because you can get the orbit as precisely as you want, and so eliminate any noise a cruder approximation would insert in subsequent calculations. If that makes any difference to you.
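To make “strongly” vs “weakly stable” concrete: the eigenvalues of the Jacobian $D\phi(x^*)$ of the one-year map at the fixed point (the Floquet multipliers) say how fast pure time integration approaches the orbit; an error shrinks roughly like $\rho^k$ after $k$ years, with $\rho$ the largest modulus. A minimal sketch, assuming the reduced $[S, I]$ map from earlier in the thread, ω as a rate, and a central-difference Jacobian; the helper `map_jacobian` and its step size `rel` are made up for illustration:

```julia
using OrdinaryDiffEq, NLsolve, LinearAlgebra

# SIR model with ω as a rate (1/day), as in the first posts of the thread
function SIR!(du, u, p, t)
    S, I, R = u
    N = S + I + R
    λ = p.β * (1.0 + p.η * sin(2π * (t - p.φ) / 365.0)) * I / N
    du[1] = p.μ * N - λ * S - p.μ * S + p.ω * R
    du[2] = λ * S - p.σ * I - p.μ * I
    du[3] = p.σ * I - p.μ * R - p.ω * R
    return nothing
end

const Npop = 10_000.0
x2u(x) = [x[1], x[2], Npop - x[1] - x[2]]   # reduced state [S, I] -> [S, I, R]

# one-year map on x = [S, I]
function ϕ(x, p)
    sol = solve(ODEProblem(SIR!, x2u(x), (0.0, 365.0), p), Tsit5();
                abstol = 1e-10, reltol = 1e-10, save_everystep = false)
    return sol.u[end][1:2]
end

# hypothetical helper: central-difference Jacobian of the map at x
function map_jacobian(ϕ, x, p; rel = 1e-5)
    n = length(x)
    J = zeros(n, n)
    for j in 1:n
        h = rel * max(abs(x[j]), 1.0)
        e = zeros(n)
        e[j] = h
        J[:, j] = (ϕ(x .+ e, p) .- ϕ(x .- e, p)) ./ (2h)
    end
    return J
end

p = (β = 0.28, η = 0.07, φ = 20.0, ω = 1.0 / 365, μ = 1.0 / (80 * 365), σ = 1.0 / 5.0)
warm = solve(ODEProblem(SIR!, [9999.0, 1.0, 0.0], (0.0, 3650.0), p), Tsit5())
xstar = nlsolve(x -> ϕ(x, p) .- x, warm.u[end][1:2]).zero

multipliers = eigvals(map_jacobian(ϕ, xstar, p))
ρ = maximum(abs, multipliers)
println("multipliers: ", multipliers, ", ρ = ", ρ)
println("years for a 1e-6 reduction of a perturbation ≈ ", log(1e-6) / log(ρ))
```

A $\rho$ close to 1 means many years of plain integration (solution 2 in the thread), which is where the Newton approach tends to pay off.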

The second set of parameter values you give is very different from the first. For example, the first has $\omega$ at about 0.00273, the second has it at about 858.7. Looking at the equations, these lead to vastly different time scales in the ODE. That $\omega = 858$ value makes the timescale for R very small, on the order of two minutes, I think ($1/858 \times 24 \times 60 \approx 1.7$ minutes). Is that realistic? To get stable time integrations for this you need to take the time step down to $\Delta t = 0.001$, whereas before I could simulate at $\Delta t = 1$ (one day).

Is the Bayesian framework producing parameter values for the ODE simulations? Is there some way to guard its output so it gives only realistic parameters?

BTW, I am really enjoying working on this problem and would be very happy to continue, if it helps you. It’s related to my research in fluid mechanics and pretty close to teaching I was doing this spring. And ever since the pandemic hit I’ve wished I could contribute to epidemiological modeling. So I’ll go on as long as you find it helpful.


https://turing.ml/dev/tutorials/10-bayesiandiffeq/

Transform the parameters if their domains are simple. Otherwise, you can use `sol.retcode == :Success` to check whether the solver failed, and reject bad steps.
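For box-shaped domains like the Uniform priors used in this thread, one common transformation is a scaled logistic map from an unconstrained value z to the interval (lo, hi). A minimal sketch in plain Julia; the helper names are hypothetical and not part of Turing.jl or DifferentialEvolutionMCMC.jl, and the log-Jacobian term is only needed if the prior is stated on θ while the sampler moves in z:

```julia
# hypothetical helpers for interval-constrained parameters
logistic(z) = 1 / (1 + exp(-z))

# map an unconstrained real z into the open interval (lo, hi)
to_interval(z, lo, hi) = lo + (hi - lo) * logistic(z)

# inverse map, from θ in (lo, hi) back to z
from_interval(θ, lo, hi) = log((θ - lo) / (hi - θ))

# log|dθ/dz| = log(hi - lo) + log(s) + log(1 - s) with s = logistic(z), written stably
logjac_interval(z, lo, hi) = log(hi - lo) - log1p(exp(-z)) - log1p(exp(z))

# example with the bounds used for ω (duration of immunity, days) in the thread
lo, hi = 1.0, 5.0 * 365.0
z = from_interval(365.0, lo, hi)
θ = to_interval(z, lo, hi)
println("z = ", z, ", θ = ", θ, ", log|dθ/dz| = ", logjac_interval(z, lo, hi))
```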


Ah, sorry about ω, that’s my mistake; I forgot to mention this: I changed the parametrization between the posts. It used to be the rate of leaving the R compartment (i.e. 1/duration of immunity in days), but I changed it for better interpretability, so it is now the duration of immunity (days) itself, and thus in the model equations it enters as 1/ω. So if you want to compare the params, it’s 0.00273/day vs. 0.001165/day. Sorry about that. It is correctly implemented in the model. I have attached the full code including the fitting for reproducibility. The time unit in the model is days. The time step in the integration is determined by the solver; I don’t think I can influence it other than by adjusting the tolerances?
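Besides the tolerances, `solve` accepts keyword arguments that constrain the step directly: `dtmax` caps the adaptive step, and `dt` together with `adaptive = false` forces fixed steps without error control. A minimal sketch, assuming the 3-state version of the model in this post (ω as the duration of immunity) and parameter values from earlier in the thread:

```julia
using OrdinaryDiffEq

# 3-state version of the model in this post, with ω as the duration of immunity (days)
function SIR!(du, u, p, t)
    S, I, R = u
    N = S + I + R
    λ = p.β * (1.0 + p.η * sin(2π * (t - p.φ) / 365.0)) * I / N
    du[1] = p.μ * N - λ * S - p.μ * S + (1 / p.ω) * R
    du[2] = λ * S - p.σ * I - p.μ * I
    du[3] = p.σ * I - p.μ * R - (1 / p.ω) * R
    return nothing
end

p = (β = 0.28, η = 0.07, φ = 20.0, ω = 365.0, μ = 1.0 / (80 * 365), σ = 1.0 / 5.0)
prob = ODEProblem(SIR!, [9999.0, 1.0, 0.0], (0.0, 3650.0), p)

sol_default = solve(prob, Tsit5())                          # step size set by error control only
sol_capped = solve(prob, Tsit5(); dtmax = 1.0)              # adaptive, but no step longer than one day
sol_fixed = solve(prob, RK4(); dt = 1.0, adaptive = false)  # fixed one-day steps, no error control

for (name, s) in (("default", sol_default), ("dtmax = 1", sol_capped), ("fixed dt = 1", sol_fixed))
    println(name, ": ", length(s.t) - 1, " steps, largest step ", maximum(diff(s.t)))
end
```

Fixed steps trade away error control, so with very fast rates (such as the ω = 858/day reading discussed above) an explicit method may need far smaller `dt`, or a stiff solver such as the `Rodas5()` fallback already used in `AutoVern7(Rodas5())`.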

I expect the periodic orbit solution in my model to be stable, because the data for the disease I am working on have shown a very consistent and stable seasonality for decades. The chains produced by the parameter sampling algorithm should converge towards a parameter set that reproduces this pattern. So I have an a priori expectation of a stable solution.

I like the fact that the non-linear solver solution provides the more accurate solutions for the inits, but the time consumption is a considerable issue when fitting models. My real model takes about 10 hours to run with solution 2 (multi threading). But I could probably optimize the code to make it more efficient. That’s where my Julia skills end.

For fitting, I am using the DifferentialEvolutionMCMC.jl library (a genetic algorithm), but the Turing.jl library should give a similar result (hopefully). There is a small difference in how it is set up with regards to the likelihood. I prefer to write it out, but that’s a matter of taste. The output is an array (?) of parameter draws (posterior distribution) for each iteration you run it (minus burn-in). You can then randomly sample from these and forward simulate the model with each param set to obtain different model trajectories to judge the fit.

Yes, disease modelling has gotten some fame (and infamy) in the past year… The discussion we are having is extremely interesting and useful for me. I am a computational epidemiologist, not a mathematician, so I love learning from people with a sounder knowledge of the math underlying it all.

```julia
using StatsPlots, NLsolve, LinearAlgebra, DifferentialEquations, DifferentialEvolutionMCMC, StatsBase, Distributions, Random, LabelledArrays

# SIR ODE model with seasonally forced contact rate
function SIR!(du,u,p,t)

    # states
    S, I, R = u # Susceptibles, Infected, Recovered
    N = S + I + R

    # params
    β = p.β # average rate of infection
    η = p.η # amplitude around average rate of infection
    φ = p.φ # phase shift
    ω = p.ω # duration of immunity
    μ = p.μ # mortality/reproduction rate (assuming rectangular demography with 80 years life span)
    σ = p.σ # rate of leaving I compartment (1/duration of infectiousness)

    # FOI
    βeff = β * (1.0+η*sin(2.0*π*(t-φ)/365.0)) # periodically forced infection rate
    λ = βeff*I/N # force of infection FOI

    # change in states
    du[1] = (μ*N - λ*S - μ*S + (1/ω)*R) #S
    du[2] = (λ*S - σ*I - μ*I) #I
    du[3] = (σ*I - μ*R - (1/ω)*R) # R
    du[4] = (σ*I) # C, cumulative incidence

end

# Initiate ODE problem
solvsettings = (abstol = 1.0e-8,
                reltol = 1.0e-8,
                maxiters = 1e5,
                saveat = 7.0, # corresponds to weekly data
                solver = AutoVern7(Rodas5())) # others: AutoTsit5(Rosenbrock23()) #Tsit5()

theta_fix = (μ=1.0/(80*365), σ=1.0/5.0)
theta_est = (β=0.28, η=0.07, φ=20, ω=365)
p = merge(theta_est, theta_fix)
u0 = @LArray [9999.0,1.0,0.0,1.0] (:S,:I,:R,:C)
N = sum(u0[1:3])
tmax = 365.0*3
problem = ODEProblem(SIR!,u0,(0.0,tmax),p)
inits = solve(problem,
              solvsettings.solver,
              abstol=solvsettings.abstol,
              reltol=solvsettings.reltol,
              maxiters=solvsettings.maxiters,
              isoutofdomain=(u,p,t)->any(z->z<0.0,u),
              saveat=solvsettings.saveat)

plot(inits(0:tmax,idxs=2), inits(0:tmax,idxs=1))
plot(inits)

# functions to translate between u = [S I R C] and x = [S I]
function u2x(u)
    return u[1:2]
end

u2x(u0) # test

function x2u(x, N)
    return [x[1:2] ; N-sum(x) ; x[2]]
end

x2u(u0[1:2], N) # test

# Wrapper for ODE: integrates over problem.tspan, so the map period is the problem's tspan
function update(problem, x, p, N, solvsettings)
    problem_new = remake(problem; u0=x2u(x, N), p=p)
    sol_new = solve(problem_new,
                    solvsettings.solver,
                    abstol=solvsettings.abstol,
                    reltol=solvsettings.reltol,
                    maxiters=solvsettings.maxiters,
                    isoutofdomain=(u,p,t)->any(z->z<0.0,u),
                    saveat=solvsettings.saveat)
    return u2x(sol_new.u[end])
end

update(problem, u2x(inits.u[end]), p, N, solvsettings) # test

# solve for periodic solution
solution = nlsolve(x -> update(problem, x, p, N, solvsettings)-x, u2x(inits.u[end]))

# check difference
xstar = solution.zero
@show norm(xstar - update(problem, xstar,p, N, solvsettings))

# convert from x = SI to u = SIR
ustar = x2u(xstar, N)
@show sum(ustar[1:3])

# Plug new inits in function and plot
final = solve(remake(problem; u0=ustar, tspan=(0.0,3650.0)), solvsettings.solver)
plot(final(0:3650,idxs=2), final(0:3650,idxs=1)) # yay, looks great
plot(final)

Ncheck = final(0:3650, idxs=1) .+ final(0:3650, idxs=2) .+ final(0:3650, idxs=3)
minimum(Ncheck) # some numerical imprecisions, but not dramatic

# Fitted -------------------------------------------------------------------------------------

# Calculate incident cases from cumulative incidence
function get_incidence(sol)
    incidence = [sol.u[t].C - sol.u[t-1].C for t = 2:length(sol.u)]
    incidence = ifelse.(incidence .< 0.0, 0.0, incidence) # to exclude small negative incidences due to numerical imprecision in C
    return incidence
end

# Fake some data
data = solve(remake(problem; u0=u0, tspan=(0.0,25*365.0), p=p),
             solvsettings.solver,
             abstol=solvsettings.abstol,
             reltol=solvsettings.reltol,
             maxiters=solvsettings.maxiters,
             isoutofdomain=(u,p,t)->any(z->z<0.0,u),
             saveat=solvsettings.saveat)

plot(data)
data = get_incidence(data)
data = data[length(data)-5*52 : length(data)] # take the last ~5 years of data for fitting, discard rest
plot(data)
data = round.(data .* rand(Uniform(0.5,1.5), size(data)))
scatter!(data,legend = false)

# Priors
priors = (
    β = (Uniform(0.0,1.0),),
    η = (Uniform(0.0,1.0),),
    ω = (Uniform(1.0, 5.0*365.0),),
    φ = (Uniform(0.0,364.0),)
)

# Upper and lower bounds (in this case equivalent to arguments of the Uniform)
bounds = (
    (0.0,1.0),
    (0.0,1.0),
    (1.0, 5.0*365.0),
    (0.0,364.0)
)

parnames = keys(priors)

# Solution 1) Inits are calculated in a separate model run-----------------------------------------

tspan = (0.0, maximum(collect(range(0, step = 7, length = length(data)))))

# There is a problem with unstable solutions for the nlsolve:
# "Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable"
function loglik(data, problem, theta_fix, parnames, tspan, solvsettings, N, θ...)

    theta_est = NamedTuple{parnames}(θ)
    p_new = merge(theta_fix, theta_est)

    # 1. Solve init model until stable periodic solution is reached
    problem_new = remake(problem, p=p_new, u0=u0, tspan=(0.0,3*365.0))
    inits1 = solve(problem_new,
                   solvsettings.solver,
                   abstol=solvsettings.abstol,
                   reltol=solvsettings.reltol,
                   maxiters=solvsettings.maxiters,
                   isoutofdomain=(u,p,t)->any(x->x<0.0,u),
                   save_everystep=false)

    # use the proposed parameters p_new (not the global p) inside the fixed-point map
    nlsolved = nlsolve(x -> update(problem_new, x, p_new, N, solvsettings)-x, u2x(inits1.u[end]))

    # extract the zeros from the solution and convert to labelled array
    inits2 = @LArray x2u(nlsolved.zero, N) symbols(problem_new.u0)

    if minimum(inits2) < 0.0 # discard results where periodic solution is unstable
        return -Inf
    end
    # 2. Solve inference model with new inits
    sol = solve(remake(problem; u0=inits2, tspan=tspan, p=p_new),
                solvsettings.solver,
                abstol=solvsettings.abstol,
                reltol=solvsettings.reltol,
                maxiters=solvsettings.maxiters,
                isoutofdomain=(u,p,t)->any(z->z<0.0,u),
                saveat=solvsettings.saveat)

    # 3. calculate incidence and loglik
    incidence = vcat(sol.u[1].C, get_incidence(sol))

    ll = 0.0
    for i in 1:length(incidence)
        ll += logpdf(Poisson(incidence[i]), data[i])
    end

    #ll = sum(logpdf.(Poisson.(incidence), data))
    return ll
end

loglik(data, problem, theta_fix, parnames, tspan, solvsettings, N, theta_est...) # test

# Run DE MCMC
burnin=2000
n_iter = 5000 # not enough for convergence, for benchmarking purposes only
n_groups = 3 # default is 4
Np = length(priors)*3 # N of particles per group. Default is nparam*3
model = DEModel(problem, theta_fix, parnames, tspan, solvsettings, N; priors, model=loglik, data=data)
de = DE(; n_groups, Np, bounds, burnin=burnin, priors)

# Solution 2) Inits are simulated in the same model run-----------------------------------------

tspan2 = (-10*365.0, maximum(collect(range(0, step = 7, length = length(data)))))

function loglik2(data, problem, theta_fix, parnames, solvsettings, θ...)

    theta_est = NamedTuple{parnames}(θ)
    p_new = merge(theta_fix, theta_est)

    problem_new = remake(problem, p=p_new)
    sol = solve(problem_new,
                solvsettings.solver,
                abstol=solvsettings.abstol,
                reltol=solvsettings.reltol,
                maxiters=solvsettings.maxiters,
                isoutofdomain=(u,p,t)->any(z->z<0.0,u),
                saveat=solvsettings.saveat)

    incidence = get_incidence(sol)
    incidence = incidence[(length(incidence)-length(data)+1):length(incidence)]

    ll = 0.0
    for i in 1:length(incidence)
        ll += logpdf(Poisson(incidence[i]), data[i])
    end

    return ll
end

# test
problem2 = ODEProblem(SIR!,u0,tspan2,p)
loglik2(data, problem2, theta_fix, parnames, solvsettings, theta_est...) # test

# Run DE MCMC
model2 = DEModel(problem2, theta_fix, parnames, solvsettings; priors, model=loglik2, data=data)
```