Dear all,
I am using DifferentialEquations.jl to solve a standard SEIR Model. The ODE looks like this:
using DifferentialEquations
using BenchmarkTools, UnPack
function seir_ode!(du,u,p,t)
    (S,E,I,R,C) = u
    (β,γ,ϵ) = p
    N = S+E+I+R
    infection = β*I/N*S
    exposed = ϵ*E
    recovery = γ*I
    @inbounds begin
        du[1] = -infection
        du[2] = infection - exposed
        du[3] = exposed - recovery
        du[4] = recovery
        du[5] = infection
    end
    nothing
end
#Time Span
tmax = 100.0
tspan = (0.0, tmax)
#Initial conditions
N = 1000
s₀ = 0.9
e₀ = 0.03
i₀ = 0.02
r₀ = 1.0 - s₀ - e₀ - i₀
u0 = [N*s₀, N*e₀, N*i₀, N*r₀, N - (N - N*e₀ - N*i₀ - N*r₀)] # Last element is the cumulative number of cases
#Parameter
β = .5
p = (β, 0.1, 0.1)
prob_ode = ODEProblem(seir_ode!,u0,tspan,p)
sol_ode = solve(prob_ode, Tsit5(), saveat = 1.0, dt = 0.1)
This works! Now, in my actual problem I have to solve an ODE in which one of the parameters, β, is time-dependent. I have to estimate this parameter sequentially and need the ODE state from the previous step to do so, so I cannot just sample β over time first and then solve the ODE all at once; I need to do this in a loop. My current workflow looks like this:
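u_current = copy(u0) # State carried between steps; start from the initial condition
for t in 1:Int(tmax)
    # Estimate β, here for simplicity:
    β = rand([.2, .5])
    # Assign new time step and parameters
    p_new = (β, 0.1, 0.1)
    tspan_new = (float(t-1), float(t))
    # Solve the ODE for the given β, saving only the state at the end of the step
    prob_ode = ODEProblem(seir_ode!, u_current, tspan_new, p_new)
    sol_ode = solve(prob_ode, Tsit5(), saveat = [float(t)], dt = 0.01)
    global u_current = sol_ode.u[end] # `global` is needed when this runs as a script
    # Do some other stuff with the ODE solution
end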
Q1: Is there a way to initialize prob_ode only once and then reuse this container? ODEProblem is immutable, so I cannot just change the values inside the struct. The parameters and the time span change, but the buffer size should be constant across iterations. At the moment my workflow seems wasteful.
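For Q1, one possibility might be the integrator interface, which sets up the solver once and then advances it step by step. The following is only a minimal sketch under stated assumptions, not a definitive implementation: the parameters are stored in a mutable Vector (instead of the tuple used above) so that β can be overwritten in place, `rand` stands in for the actual β estimate, and `u_modified!` is called on the assumption that this is enough to make the solver recompute its cached derivative after the parameter change.

```julia
using DifferentialEquations

# Same right-hand side as in the post, repeated so the sketch runs on its own.
function seir_ode!(du, u, p, t)
    (S, E, I, R, C) = u
    (β, γ, ϵ) = p
    N = S + E + I + R
    infection = β * I / N * S
    exposed = ϵ * E
    recovery = γ * I
    du[1] = -infection
    du[2] = infection - exposed
    du[3] = exposed - recovery
    du[4] = recovery
    du[5] = infection
    nothing
end

u0 = [900.0, 30.0, 20.0, 50.0, 100.0]  # same values as N*s₀, N*e₀, ... above
tmax = 100.0
p = [0.5, 0.1, 0.1]                     # assumption: a Vector, so p[1] (β) is mutable

prob = ODEProblem(seir_ode!, u0, (0.0, tmax), p)
# Build the solver state once; save_everystep = false keeps the stored output small.
integrator = init(prob, Tsit5(); save_everystep = false)

for t in 1:Int(tmax)
    # Advance exactly one time unit with the current β.
    step!(integrator, 1.0, true)
    # integrator.u now holds the state at time t; use it to estimate the next β.
    integrator.p[1] = rand([0.2, 0.5])  # placeholder for the real estimate
    # Flag the change so the cached derivative is refreshed (assumed sufficient).
    u_modified!(integrator, true)
end
```

Each pass through the loop works on the same integrator rather than building a new problem and solution. If a fresh problem per step is still preferred, `remake` with new `u0`, `tspan` and `p` keyword arguments may be another option.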
Q2: If applicable, are there any improvements I could make to the ODE itself? There are still some allocations left, but I assume most of them come from creating the vectors that store the output?
using BenchmarkTools
prob_ode = ODEProblem(seir_ode!, u0, tspan, p)
@btime solve($prob_ode, $Tsit5(), saveat = $1.0, dt = $0.1) # 26.600 μs (440 allocations: 77.27 KiB)
@btime solve($prob_ode, $Tsit5(), saveat = $tmax, dt = $0.1) # 6.620 μs (42 allocations: 5.27 KiB)
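For Q2, one option that is often suggested for small systems is an out-of-place right-hand side on static arrays. The sketch below assumes the StaticArrays.jl package is installed; the names `seir_ode_static`, `u0_static` and `prob_static` are made up for illustration, and whether this actually reduces allocations for this model would have to be checked with `@btime`.

```julia
using DifferentialEquations, StaticArrays

# Out-of-place variant: returns a new SVector instead of writing into du.
function seir_ode_static(u, p, t)
    (S, E, I, R, C) = u
    (β, γ, ϵ) = p
    N = S + E + I + R
    infection = β * I / N * S
    exposed = ϵ * E
    recovery = γ * I
    return SVector(-infection, infection - exposed, exposed - recovery,
                   recovery, infection)
end

# Same initial values as N*s₀, N*e₀, N*i₀, N*r₀ and the cumulative cases above.
u0_static = SVector(900.0, 30.0, 20.0, 50.0, 100.0)
p = (0.5, 0.1, 0.1)
prob_static = ODEProblem(seir_ode_static, u0_static, (0.0, 100.0), p)

# save_everystep = false stores only the initial and final state, which helps
# separate the solver's own allocations from those of the saved output.
sol_static = solve(prob_static, Tsit5(); save_everystep = false)
```

Comparing this against the in-place version with the same saving options should show how much of the remaining allocation count comes from the output vectors.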