Problem with parameter estimation in a epidemiological model

Hello, I’m new to Julia. I’m trying to estimate parameter of a system of differential equations for a epidemiological model. I can solve the system but i can’t fit the parameter to my data.

function SEAIRD!(du, u, p, t)

s, e, a, i, r, d = u
α, β, γ, ξ, τ, χ, ϕ = p

ω = 0.4

λ = ϕ*β*(i + (e + a)/2)/n

du[1] = α*r -λ*s
du[2] = λ*ω*s - τ*e
du[3] = λ*(1 - ω)*s - γ*a 
du[4] = τ*e - (ξ + χ)*i
du[5] = ξ*i + γ*a -α*r
du[6] = χ*i

end

function loss(p)
sol = solve(prob, Rosenbrock23(), p=p, saveat=tsteps)
casos = sol[2,:] + sol[3,:]
erro = sum(casos - proporcao_casos)
return erro
end

e₀ = 105200
a₀ = 0
i₀ = 0
r₀ = 0
d₀ = 0
n = 44882953

s₀ = n - (e₀ + a₀ + i₀ + r₀ + d₀)

α = 0.005
β = 0.06
γ = 0.0125
ξ = 0.3
τ = 0.2
χ = 0.01
ϕ = 5

u0 = [s₀, e₀, a₀, i₀, r₀, d₀]

p = [α, β, γ, ξ, τ, χ, ϕ]

tspan = (1.0, 440.0)
tsteps = 1.0:1.0:440.0

prob = ODEProblem(SEAIRD!, u0, tspan, p)
sol = solve(prob, Rosenbrock23())

result_ode = DiffEqFlux.sciml_train(loss, p,
ADAM(0.1),
maxiters = 100)

I get the error

┌ Warning: Instability detected. Aborting └ @ SciMLBase C:\Users\fepra.julia\packages\SciMLBase\DKeLA\src\integrator_interface.jl:351 ┌ Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable. └ @ SciMLBase C:\Users\fepra.julia\packages\SciMLBase\DKeLA\src\integrator_interface.jl:345 ┌ Warning: Instability detected. Aborting └ @ SciMLBase C:\Users\fepra.julia\packages\SciMLBase\DKeLA\src\integrator_interface.jl:351

DimensionMismatch(“dimensions must match: a has dims (Base.OneTo(289),), b has dims (Base.OneTo(440),), mismatch at 1”)

Anyone can tell me what I’m doing wrong?

Handling Divergent and Unstable Trajectories · DiffEqFlux.jl explains this phenomena and what to do.

Hi, thanks for your reply. I followed the instructions but now i got the following error

ethodError: Cannot convert an object of type Nothing to an object of type Float64
Closest candidates are:
convert(::Type{T}, ::VectorizationBase.AbstractSIMD) where T<:Union{Bool, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, VectorizationBase.Bit} at C:\Users\fepra.julia\packages\VectorizationBase\y1g7W\src\base_defs.jl:151

My new loss function is

function loss(p)
tmp_prob = remake(prob, p=p)
tmp_sol = Array(solve(tmp_prob,Tsit5(),saveat=0.1))
casos = tmp_sol[2,:] + tmp_sol[3,:]
if size(casos) == size(proporcao_casos)
return sum(abs2,casos - proporcao_casos)
else
return Inf
end
end

Thanks

Hi, the parameter estimation worked, but the optimized parameter are equal the initial parameters

result_ode = DiffEqFlux.sciml_train(loss, p, ADAM(0.1), maxiters = 100)

u: 7-element Vector{Float64}:
0.005
0.06
0.0125
0.3
0.2
0.01
9.0

α = 0.005
β = 0.06
γ = 0.0125
ξ = 0.3
τ = 0.2
χ = 0.01
ϕ = 9

What’s the gradient of your loss w.r.t. p? Check with Zygote.gradient directly. My guess would be that maybe you were solving accidentally with the old parameters each time. paste what your full code is now.

Hi, here is the plot of my data

The plot of fitted parameters

Here is my code

function SEAIRD!(du, u, p, t)

s, e, a, i, r, d = u
α, β, γ, ξ, τ, χ, ϕ = p

ω = 0.4

λ = ϕ*β*(i + (e + a)/2)/n

du[1] = α*r - λ*s
du[2] = λ*ω*s - τ*e
du[3] = λ*(1 - ω)*s - γ*a 
du[4] = τ*e - (ξ + χ)*i
du[5] = ξ*i + γ*a -α*r
du[6] = χ*i

end

function loss(p)
tmp_prob = remake(prob, p=p)
tmp_sol = Array(solve(tmp_prob,Tsit5(),saveat=0.1))
casos = tmp_sol[2,:] + tmp_sol[3,:]
if size(casos) == size(proporcao_casos)
return sum(abs2, casos - proporcao_casos)
else
return Inf
end
end

prob = ODEProblem(SEAIRD!, u0, tspan, p)
sol = solve(prob, saveat=0.1)

e₀ = 1052
a₀ = 0
i₀ = 0
r₀ = 0
d₀ = 0
n = 44882953

s₀ = n - (e₀ + a₀ + i₀ + r₀ + d₀)

α = 0.005
β = 0.06
γ = 0.0125
ξ = 0.3
τ = 0.2
χ = 0.01
ϕ = 9

u0 = [s₀, e₀, a₀, i₀, r₀, d₀]

p = [α, β, γ, ξ, τ, χ, ϕ]

tspan = (1.0, 440.0)
tsteps = 1.0:1.0:440.0

tmp_prob = remake(prob, p=[0.005, 0.08, 0.025, 0.06, 0.2, 0.02, 6])
tmp = solve(tmp_prob)

pinit = [0.005, 0.08, 0.025, 0.06, 0.2, 0.02, 6]

result_ode = DiffEqFlux.sciml_train(loss, pinit, ADAM(0.1), maxiters = 100)

remade_solution = solve(remake(prob, p = result_ode.minimizer), Tsit5(), saveat = tsteps)

Here is my data array:

proporcao_casos = [1.052000e+03, 1.223000e+03, 1.406000e+03, 1.451000e+03,
1.517000e+03, 2.339000e+03, 2.981000e+03, 3.506000e+03,
4.048000e+03, 4.466000e+03, 4.620000e+03, 4.861000e+03,
5.682000e+03, 6.708000e+03, 7.480000e+03, 8.216000e+03,
8.419000e+03, 8.755000e+03, 8.895000e+03, 9.371000e+03,
1.104300e+04, 1.156800e+04, 1.284100e+04, 1.389400e+04,
1.426700e+04, 1.458000e+04, 1.538500e+04, 1.591400e+04,
1.674000e+04, 1.782600e+04, 2.000400e+04, 2.071500e+04,
2.169600e+04, 2.404100e+04, 2.615800e+04, 2.869800e+04,
3.037400e+04, 3.117400e+04, 3.177200e+04, 3.218700e+04,
3.405300e+04, 3.785300e+04, 3.992800e+04, 4.183000e+04,
4.441100e+04, 4.544400e+04, 4.613100e+04, 4.771900e+04,
5.109700e+04, 5.428600e+04, 5.837800e+04, 6.118300e+04,
6.234500e+04, 6.306600e+04, 6.599500e+04, 6.985900e+04,
7.373900e+04, 7.687100e+04, 8.055800e+04, 8.216100e+04,
8.362500e+04, 8.601700e+04, 8.948300e+04, 9.586500e+04,
1.015560e+05, 1.071420e+05, 1.096980e+05, 1.112960e+05,
1.182950e+05, 1.234830e+05, 1.292000e+05, 1.345650e+05,
1.405490e+05, 1.430730e+05, 1.445930e+05, 1.501380e+05,
1.563160e+05, 1.625200e+05, 1.679000e+05, 1.728750e+05,
1.782020e+05, 1.814600e+05, 1.902850e+05, 1.915170e+05,
1.926280e+05, 2.116580e+05, 2.157930e+05, 2.191850e+05,
2.219730e+05, 2.294750e+05, 2.388220e+05, 2.485870e+05,
2.585080e+05, 2.655810e+05, 2.717370e+05, 2.751450e+05,
2.813800e+05, 2.899350e+05, 3.021790e+05, 3.105170e+05,
3.125300e+05, 3.201790e+05, 3.230700e+05, 3.327080e+05,
3.413650e+05, 3.497150e+05, 3.591100e+05, 3.668900e+05,
3.719970e+05, 3.746070e+05, 3.866070e+05, 3.931760e+05,
4.020480e+05, 4.074150e+05, 4.120270e+05, 4.150490e+05,
4.164340e+05, 4.226690e+05, 4.394460e+05, 4.520070e+05,
4.632180e+05, 4.794810e+05, 4.839820e+05, 4.876540e+05,
5.003010e+05, 5.141970e+05, 5.290060e+05, 5.423040e+05,
5.523180e+05, 5.586850e+05, 5.602180e+05, 5.755890e+05,
5.852650e+05, 5.986700e+05, 6.083790e+05, 6.217310e+05,
6.271260e+05, 6.284150e+05, 6.395620e+05, 6.551810e+05,
6.744550e+05, 6.861220e+05, 6.975300e+05, 6.994930e+05,
7.026650e+05, 7.115300e+05, 7.213770e+05, 7.308280e+05,
7.359600e+05, 7.492440e+05, 7.541290e+05, 7.564800e+05,
7.656700e+05, 7.761350e+05, 7.844530e+05, 7.962090e+05,
8.014220e+05, 8.034040e+05, 8.043420e+05, 8.143750e+05,
8.263310e+05, 8.379780e+05, 8.450160e+05, 8.530850e+05,
8.557220e+05, 8.573300e+05, 8.587830e+05, 8.665760e+05,
8.747540e+05, 8.828090e+05, 8.906900e+05, 8.922570e+05,
8.933490e+05, 9.012710e+05, 9.094280e+05, 9.168210e+05,
9.245320e+05, 9.316730e+05, 9.353000e+05, 9.373320e+05,
9.454220e+05, 9.519730e+05, 9.582400e+05, 9.649210e+05,
9.708880e+05, 9.722370e+05, 9.731420e+05, 9.795190e+05,
9.856280e+05, 9.917250e+05, 9.973330e+05, 1.003429e+06,
1.003902e+06, 1.004579e+06, 1.010839e+06, 1.016755e+06,
1.022404e+06, 1.028190e+06, 1.034816e+06, 1.037660e+06,
1.038344e+06, 1.039029e+06, 1.045060e+06, 1.051613e+06,
1.057240e+06, 1.062634e+06, 1.063602e+06, 1.064039e+06,
1.068962e+06, 1.073261e+06, 1.076939e+06, 1.083641e+06,
1.089255e+06, 1.091980e+06, 1.092843e+06, 1.098207e+06,
1.103582e+06, 1.108860e+06, 1.113788e+06, 1.116127e+06,
1.117147e+06, 1.117795e+06, 1.118544e+06, 1.123299e+06,
1.125936e+06, 1.125936e+06, 1.125936e+06, 1.125936e+06,
1.125936e+06, 1.147451e+06, 1.150872e+06, 1.156652e+06,
1.162782e+06, 1.167422e+06, 1.168640e+06, 1.169377e+06,
1.178075e+06, 1.184496e+06, 1.191290e+06, 1.200348e+06,
1.205435e+06, 1.209588e+06, 1.210625e+06, 1.215844e+06,
1.224744e+06, 1.229267e+06, 1.233587e+06, 1.238094e+06,
1.240473e+06, 1.241653e+06, 1.250590e+06, 1.259704e+06,
1.267912e+06, 1.276149e+06, 1.285087e+06, 1.287762e+06,
1.288878e+06, 1.296801e+06, 1.306585e+06, 1.316371e+06,
1.325162e+06, 1.333763e+06, 1.334703e+06, 1.337016e+06,
1.341428e+06, 1.341428e+06, 1.361731e+06, 1.371653e+06,
1.384100e+06, 1.384100e+06, 1.388043e+06, 1.398757e+06,
1.409140e+06, 1.418491e+06, 1.422087e+06, 1.423340e+06,
1.426176e+06, 1.427752e+06, 1.440229e+06, 1.452078e+06,
1.462297e+06, 1.466191e+06, 1.467953e+06, 1.471422e+06,
1.473670e+06, 1.486551e+06, 1.501085e+06, 1.515158e+06,
1.528952e+06, 1.540513e+06, 1.546132e+06, 1.549142e+06,
1.561844e+06, 1.577119e+06, 1.590829e+06, 1.605845e+06,
1.619619e+06, 1.625339e+06, 1.628272e+06, 1.644225e+06,
1.658636e+06, 1.670754e+06, 1.679759e+06, 1.694355e+06,
1.699427e+06, 1.702294e+06, 1.715253e+06, 1.731294e+06,
1.746070e+06, 1.759957e+06, 1.773024e+06, 1.777368e+06,
1.779722e+06, 1.794019e+06, 1.807009e+06, 1.820941e+06,
1.833163e+06, 1.845086e+06, 1.849334e+06, 1.851776e+06,
1.864977e+06, 1.878802e+06, 1.889969e+06, 1.901574e+06,
1.911411e+06, 1.913598e+06, 1.915914e+06, 1.927410e+06,
1.938712e+06, 1.949459e+06, 1.960564e+06, 1.971423e+06,
1.975927e+06, 1.978477e+06, 1.990554e+06, 2.002640e+06,
2.014529e+06, 2.026125e+06, 2.037267e+06, 2.041628e+06,
2.044699e+06, 2.054867e+06, 2.068616e+06, 2.080852e+06,
2.093924e+06, 2.107687e+06, 2.113738e+06, 2.117962e+06,
2.134020e+06, 2.149561e+06, 2.164066e+06, 2.179786e+06,
2.195130e+06, 2.202983e+06, 2.208242e+06, 2.225926e+06,
2.243868e+06, 2.261360e+06, 2.280033e+06, 2.298061e+06,
2.306326e+06, 2.311101e+06, 2.332043e+06, 2.352438e+06,
2.370885e+06, 2.392374e+06, 2.410498e+06, 2.420100e+06,
2.425320e+06, 2.446680e+06, 2.469849e+06, 2.496416e+06,
2.513178e+06, 2.520204e+06, 2.527400e+06, 2.532047e+06,
2.554841e+06, 2.576362e+06, 2.597366e+06, 2.618067e+06,
2.635378e+06, 2.643534e+06, 2.648844e+06, 2.667241e+06,
2.686031e+06, 2.704098e+06, 2.722077e+06, 2.739823e+06,
2.746217e+06, 2.750300e+06, 2.769360e+06, 2.786483e+06,
2.793750e+06, 2.811562e+06, 2.827833e+06, 2.834321e+06,
2.838233e+06, 2.856225e+06, 2.873238e+06, 2.888158e+06,
2.903709e+06, 2.918044e+06, 2.923367e+06, 2.926516e+06,
2.941980e+06, 2.956210e+06, 2.969680e+06, 2.984182e+06,
2.997282e+06, 3.003067e+06, 3.006250e+06, 3.022568e+06,
3.038240e+06, 3.053889e+06, 3.069804e+06, 3.085290e+06,
3.092844e+06, 3.096845e+06, 3.112624e+06, 3.129412e+06,
3.147348e+06, 3.163859e+06, 3.180595e+06, 3.188105e+06,
3.192727e+06, 3.210204e+06, 3.226875e+06, 3.239657e+06,
3.241240e+06, 3.254893e+06, 3.265930e+06, 3.272043e+06,
3.291509e+06, 3.314631e+06, 3.334364e+06, 3.338262e+06,
3.355201e+06, 3.365160e+06, 3.370234e+06, 3.378256e+06]

Thanks

function loss(p)
    tmp_prob = remake(prob, p=p)
    tmp_sol = Array(solve(tmp_prob,Tsit5(),saveat=0.1))
    casos = tmp_sol[2,:] + tmp_sol[3,:]
    if size(casos) == size(proporcao_casos)
        return sum(abs2, casos - proporcao_casos)
    else
        @show size(casos), size(proporcao_casos)
        return Inf
    end
end
(size(casos), size(proporcao_casos)) = ((4391,), (440,))

It seems like your output doesn’t match the size of your data, and so it’s just always a cost of Inf and exits.