Hi,
I'm trying to estimate the parameters of an HPV transmission model with DiffEqFlux, but the optimization seems to get stuck in local optima. I'm not sure mini-batching would help, because my observation data has only 6 time points. I started from the documentation tutorial on parameter estimation for stiff systems, which is why `allow_f_increases=true` is set for BFGS. I'm really stuck on this issue and don't know where to start; any help would make my day.
Thanks.
Mevlüt.
using DifferentialEquations, DiffEqFlux, LinearAlgebra
using ForwardDiff
using DiffEqBase: UJacobianWrapper
using Plots
using Random
# NOTE(review): nothing in this script draws random numbers, so the seed has no visible effect.
Random.seed!(15000)
# Array dimensions: g genders, s risk groups, a age groups.
# s_ and a_ size the second risk/age index of probmat, balmat and Cmat.
const g = 2
const s = 3
const s_ = 3
const a = 4
const a_ = 4
# Only the lengths of these ranges are used by the init_* functions below.
const gendergr = range(0,stop=1,length=g)
const sriskgr = range(0,stop=1,length=s)
const agegr = range(0,stop=1,length=a)
# NOTE(review): TF and timefinale are not referenced elsewhere in this file.
const TF = 6
const timefinale = range(0,stop=1,length=TF)
# Initial fractions used by init_mevotoar to seed compartments 2/3 (female) and 16/17 (male).
# NOTE(review): prevm2 equals prevm1 (0.0194) - confirm this is intended.
const prevf1 = 0.0746
const prevf2 = 0.0048
const prevm1 = 0.0194
const prevm2 = 0.0194
# Three-element weights (each sums to 1) that init_ngsa uses to split Nfemale[k] across
# the s risk groups. Rsf1 also weights female recruitment in ODEHPV.
const Rsf1 = [0.475,0.3675,0.1575]
const Rsf2 = [0.809,0.1337,0.0573]
const Rsf3 = [0.929,0.0497,0.0213]
const Rsf4 = [0.955,0.0315,0.0135]
const Rsf5 = [0.979,0.0147,0.0063]
const Rsf6 = [0.971,0.0203,0.0087]
const Rsf7 = [0.976,0.0168,0.0072]
const Rsf8 = [0.9808,0.01344,0.00576]
# Same role for Nmale; Rsm1 also weights male recruitment in ODEHPV.
const Rsm1 = [0.0986,0.63098,0.27042]
const Rsm2 = [0.5254,0.33222,0.14238]
const Rsm3 = [0.821,0.1253,0.0537]
const Rsm4 = [0.9201,0.05593,0.02397]
const Rsm5 = [0.9628,0.026,0.0112]
const Rsm6 = [0.971,0.0203,0.0087]
const Rsm7 = [0.9807,0.0135,0.0058]
const Rsm8 = [0.9907,0.00651,0.00279]
# Target population-weighted mean of Cgsa (see the definition of cmin).
const coverall = 0.456
# Eight population counts per gender; init_ngsa merges them pairwise into the a = 4 age groups.
const Nfemale = [3211000.0, 3198000.0, 3100000.0, 3131000.0, 3269000.0, 2859000.0, 2499000.0, 3617000.0]
const Nmale = [3371000.0, 3303000.0, 3179000.0, 3201000.0, 3316000.0, 2888000.0, 2536000.0, 3425000.0]
# Total female / male populations (used for normalisation in predict_adjoint).
const NF = sum(Nfemale)
const NM = sum(Nmale)
"""
    init_probmat(dz, js, ka)

Build a 0/1 array of size `(length(dz), length(js), length(js), length(ka), length(ka))`.

Each listed `(gender, risk, second risk)` triple gets an identity matrix over the two age
indices; every other entry is zero. The hard-coded triples require `length(dz) >= 2` and
`length(js) >= 3`.
"""
function init_probmat(dz, js, ka)
    nz, nj, nk = length(dz), length(js), length(ka)
    pm = zeros(nz, nj, nj, nk, nk)
    eye_k = Matrix{Float64}(I, nk, nk)
    # (gender, risk, second risk) slices that are "on"
    linked = ((1, 1, 1), (1, 2, 2), (1, 3, 3),
              (2, 1, 1), (2, 2, 2), (2, 3, 3),
              (1, 2, 3), (1, 3, 2),
              (2, 2, 3), (2, 3, 2))
    for (z, r1, r2) in linked
        pm[z, r1, r2, :, :] .= eye_k
    end
    return pm
end
# 0/1 array indexed (gender, risk, second risk, age, second age); the second indices are
# presumably the partner's group - confirm.
const probmat = init_probmat(gendergr,sriskgr,agegr)
"""
    init_ngsa(dz, js, ka)

Return the population array `Ngsa[gender, risk, age]`.

Gender 1 is built from `Nfemale`/`Rsf*` and gender 2 from `Nmale`/`Rsm*`. Age group `k`
merges the source bands `2k-1` and `2k`, each split across risk groups by its weight
vector. Only age groups 1-4 are filled.
"""
function init_ngsa(dz, js, ka)
    pop = zeros(length(dz), length(js), length(ka))
    fshares = (Rsf1, Rsf2, Rsf3, Rsf4, Rsf5, Rsf6, Rsf7, Rsf8)
    mshares = (Rsm1, Rsm2, Rsm3, Rsm4, Rsm5, Rsm6, Rsm7, Rsm8)
    for k in 1:4
        lo, hi = 2k - 1, 2k
        pop[1, :, k] = (fshares[lo] * Nfemale[lo]) + (fshares[hi] * Nfemale[hi])
    end
    for k in 1:4
        lo, hi = 2k - 1, 2k
        pop[2, :, k] = (mshares[lo] * Nmale[lo]) + (mshares[hi] * Nmale[hi])
    end
    return pop
end
"""
    init_rgsa(dz, js, ka)

Return the relative weight array `Rgsa[gender, risk, age]`.

Genders 1 and 2 both receive the outer product of the risk weights `[1, 3, 5]` and the
age weights `[4, 3, 2, 1]`.
"""
function init_rgsa(dz, js, ka)
    weights = [1.0, 3.0, 5.0] * transpose([4.0, 3.0, 2.0, 1.0])
    out = zeros(length(dz), length(js), length(ka))
    for gender in 1:2
        out[gender, :, :] = weights
    end
    return out
end
# Population per (gender, risk group, age group).
const Ngsa = init_ngsa(gendergr,sriskgr,agegr)
# Relative weights per (gender, risk group, age group); identical for both genders.
const Rgsa = init_rgsa(gendergr,sriskgr,agegr)
const Ngsa_Rgsa = Ngsa.*Rgsa
# Scale chosen so that the population-weighted mean of Cgsa equals coverall:
# sum(Cgsa .* Ngsa) / sum(Ngsa) == coverall (up to rounding).
const cmin = coverall*(sum(Ngsa)/sum(Ngsa_Rgsa))
# Scaled weights; fed into balmat/Cmat and from there into the per-cell totals used by ODEHPV.
const Cgsa = cmin .* Rgsa
"""
    init_balmat(Cgsa, probmat, Ngsa)

Return `balmat[i, k, j, l]`, the ratio of the gender-1 term
`Cgsa[1,i,j] * probmat[1,i,k,j,l] * Ngsa[1,i,j]` to the matching gender-2 term.

Dimensions are taken from `probmat` (size `(2, s, s_, a, a_)`) rather than from globals.
Entries where both terms are zero (0/0 -> NaN) are set to 0. A nonzero numerator over a
zero denominator still yields `Inf`, as before.

NOTE(review): both terms use the same `(i, j)` cell for `Cgsa` and `Ngsa`. Balancing
ratios often use the partner cell `(k, l)` on one side - confirm this is intended.
"""
function init_balmat(Cgsa, probmat, Ngsa)
    nrisk, nrisk2 = size(probmat, 2), size(probmat, 3)
    nage, nage2 = size(probmat, 4), size(probmat, 5)
    balmat = zeros(nrisk, nrisk2, nage, nage2)
    # first index innermost: column-major friendly
    for l in 1:nage2, j in 1:nage, k in 1:nrisk2, i in 1:nrisk
        female = Cgsa[1, i, j] * probmat[1, i, k, j, l] * Ngsa[1, i, j]
        male = Cgsa[2, i, j] * probmat[2, i, k, j, l] * Ngsa[2, i, j]
        ratio = female / male
        balmat[i, k, j, l] = isnan(ratio) ? 0.0 : ratio
    end
    return balmat
end
# Gender-1 / gender-2 ratio per (risk, second risk, age, second age); 0/0 mapped to 0.
const balmat = init_balmat(Cgsa,probmat,Ngsa)
"""
    init_Cmat(Cgsa, balmat)

Return `Cmat[gender, i, k, j, l]`, with sizes taken from the arguments instead of globals.

- Gender 1 keeps its rate `Cgsa[1,i,j]` (exponent 0 on `balmat`).
- Gender 2's rate `Cgsa[2,i,j]` is multiplied by `balmat[i,k,j,l]` (exponent 1).
"""
function init_Cmat(Cgsa, balmat)
    nrisk, nrisk2, nage, nage2 = size(balmat)
    Cmat = zeros(size(Cgsa, 1), nrisk, nrisk2, nage, nage2)
    for l in 1:nage2, j in 1:nage, k in 1:nrisk2, i in 1:nrisk
        Cmat[1, i, k, j, l] = Cgsa[1, i, j] * balmat[i, k, j, l]^0
        Cmat[2, i, k, j, l] = Cgsa[2, i, j] * balmat[i, k, j, l]^1
    end
    return Cmat
end
const Cmat = init_Cmat(Cgsa,balmat)
# Per-cell totals read by ODEHPV. For each (risk group i, age group j), sum the gender's
# slice of Cmat over its second risk index (dim 2) and second age index (dim 4), then
# scale by cstar.
#   actcmult  -> gender 1
#   actcmult2 -> gender 2
# Both are `const` so the ODE right-hand side does not pay for untyped global access.
actcmult_ = Cmat[1,:,:,:,:]
actcmult_2 = Cmat[2,:,:,:,:]
const cstar = 1.0
const actcmult = cstar * dropdims(sum(actcmult_; dims=(2, 4)); dims=(2, 4))
# Bug fix: the gender-2 totals were previously built from the gender-1 sums
# (`dropdims(actcnew; ...)`), so actcmult2 silently equalled actcmult.
const actcmult2 = cstar * dropdims(sum(actcmult_2; dims=(2, 4)); dims=(2, 4))
"""
    ODEHPV(du, u, p, t)

In-place right-hand side of the HPV transmission model, written into `du`.

# Arguments
- `u`, `du`: indexed `[i, j, c]` with risk group `i in 1:s`, age group `j in 1:a` and
  compartment `c in 1:19`. Compartments 1-14 are seeded from the female population and
  15-19 from the male population (see `init_mevotoar`).
- `p`: the 19 parameters, in the order unpacked below.
- `t`: not used.

# Dynamics
- Every compartment of every cell leaves at rate 0.1.
- Age group `j > 1` receives that same flow from age group `j - 1`.
- Age group 1 instead receives recruitment into compartments 1 and 15, proportional to
  the total of age group 4 (all risk groups, all 19 compartments).
"""
function ODEHPV(du, u, p, t)
    (DIf, DIPf1, DIPf2, c1, c2, PCC1, PCC2, PRG1, PRG2, DIm, DIPm1, DIPm2,
     Beta, Beta_m, SCRN, DTR, DC, cm1, cm2) = p
    # Computed once per call; previously re-sliced (allocating) twice per risk group.
    total_age4 = sum(u[:, 4, :])
    @inbounds for j in 1:a, i in 1:s
        # ---- within-cell transitions (same equations for every age group) ----
        du[i,j,1] = (u[i,j,13]+u[i,j,14])/DIf -
                    u[i,j,1]*Beta*actcmult[i,j]*((u[i,j,16]/Ngsa[2,i,j])+(u[i,j,17]/Ngsa[2,i,j])) +
                    u[i,j,8]/DC + u[i,j,9]/DTR + u[i,j,10]/DTR + u[i,j,11]/DTR + u[i,j,12]/DTR -
                    0.1*u[i,j,1]
        du[i,j,2] = u[i,j,1]*Beta*actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j]) + u[i,j,4]*(PCC1) +
                    u[i,j,14]*Beta*actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j]) -
                    u[i,j,2]/DIPf1 - u[i,j,2]*SCRN - 0.1*u[i,j,2]
        du[i,j,3] = u[i,j,1]*Beta*actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j]) + u[i,j,5]*(PCC1) +
                    u[i,j,13]*Beta*actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j]) -
                    u[i,j,3]/DIPf2 - u[i,j,3]*SCRN - 0.1*u[i,j,3]
        du[i,j,4] = u[i,j,2]*(1-c1)/DIPf1 + u[i,j,6]*PCC2 - u[i,j,4]*PRG1 - u[i,j,4]*(PCC1) -
                    u[i,j,4]*SCRN - 0.1*u[i,j,4]
        du[i,j,5] = u[i,j,3]*(1-c2)/DIPf2 + u[i,j,7]*PCC2 - u[i,j,5]*PRG1 - u[i,j,5]*(PCC1) -
                    u[i,j,5]*SCRN - 0.1*u[i,j,5]
        du[i,j,6] = u[i,j,4]*PRG1 - u[i,j,6]*PCC2 - u[i,j,6]*SCRN - u[i,j,6]*PRG2 - 0.1*u[i,j,6]
        du[i,j,7] = u[i,j,5]*PRG1 - u[i,j,7]*PCC2 - u[i,j,7]*SCRN - u[i,j,7]*PRG2 - 0.1*u[i,j,7]
        du[i,j,8] = (u[i,j,6]+u[i,j,7])*PRG2 - u[i,j,8]*SCRN - u[i,j,8]/DC - 0.1*u[i,j,8]
        du[i,j,9] = (u[i,j,2]+u[i,j,3])*SCRN - u[i,j,9]/DTR - 0.1*u[i,j,9]
        du[i,j,10] = (u[i,j,4]+u[i,j,5])*SCRN - u[i,j,10]/DTR - 0.1*u[i,j,10]
        du[i,j,11] = (u[i,j,6]+u[i,j,7])*SCRN - u[i,j,11]/DTR - 0.1*u[i,j,11]
        du[i,j,12] = u[i,j,8]*SCRN - u[i,j,12]/DTR - 0.1*u[i,j,12]
        du[i,j,13] = u[i,j,2]*(c1)/DIPf1 - u[i,j,13]*Beta*actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j]) -
                     u[i,j,13]/DIf - 0.1*u[i,j,13]
        du[i,j,14] = u[i,j,3]*(c2)/DIPf2 - u[i,j,14]*Beta*actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j]) -
                     u[i,j,14]/DIf - 0.1*u[i,j,14]
        du[i,j,15] = (u[i,j,18]+u[i,j,19])/DIm -
                     u[i,j,15]*Beta_m*actcmult2[i,j]*((u[i,j,3]/Ngsa[1,i,j])+(u[i,j,2]/Ngsa[1,i,j])) -
                     0.1*u[i,j,15]
        # Bug fix: the gains 19 -> 16 and 18 -> 17 now use Beta_m*actcmult2, matching the
        # losses in du[.,.,19] and du[.,.,18]. They previously used Beta*actcmult, so the
        # flow leaving 18/19 differed from the flow arriving in 17/16 and mass was not
        # conserved.
        du[i,j,16] = u[i,j,15]*Beta_m*actcmult2[i,j]*(u[i,j,2]/Ngsa[1,i,j]) +
                     u[i,j,19]*Beta_m*actcmult2[i,j]*(u[i,j,2]/Ngsa[1,i,j]) -
                     u[i,j,16]*cm1/DIPm1 - 0.1*u[i,j,16]
        du[i,j,17] = u[i,j,15]*Beta_m*actcmult2[i,j]*(u[i,j,3]/Ngsa[1,i,j]) +
                     u[i,j,18]*Beta_m*actcmult2[i,j]*(u[i,j,3]/Ngsa[1,i,j]) -
                     u[i,j,17]*cm2/DIPm2 - 0.1*u[i,j,17]
        du[i,j,18] = u[i,j,16]*cm1/DIPm1 - u[i,j,18]/DIm -
                     u[i,j,18]*Beta_m*actcmult2[i,j]*(u[i,j,3]/Ngsa[1,i,j]) - 0.1*u[i,j,18]
        du[i,j,19] = u[i,j,17]*cm2/DIPm2 - u[i,j,19]/DIm -
                     u[i,j,19]*Beta_m*actcmult2[i,j]*(u[i,j,2]/Ngsa[1,i,j]) - 0.1*u[i,j,19]

        # ---- inflow into the cell (appended last, so the summation order is unchanged) ----
        if j == 1
            du[i,j,1] += 0.05*Rsf1[i]*total_age4
            du[i,j,15] += 0.05*Rsm1[i]*total_age4
        else
            for c in 1:19
                du[i,j,c] += 0.1*u[i,j-1,c]
            end
        end
    end
    return nothing
end
"""
    init_mevotoar(js, ka)

Return the initial state `u0[risk, age, compartment]` of size `(length(js), length(ka), 19)`.

- Compartments 1-3 split `Ngsa[1, :, :]` by `prevf1`/`prevf2`.
- Compartments 15-17 split `Ngsa[2, :, :]` by `prevm1`/`prevm2`.
- All other compartments start at zero.
"""
function init_mevotoar(js, ka)
    state = zeros(length(js), length(ka), 19)
    fpop = Ngsa[1, :, :]
    mpop = Ngsa[2, :, :]
    state[:, :, 1] = (1 - prevf1 - prevf2) * fpop
    state[:, :, 2] = prevf1 * fpop
    state[:, :, 3] = prevf2 * fpop
    state[:, :, 15] = (1 - prevm1 - prevm2) * mpop
    state[:, :, 16] = prevm1 * mpop
    state[:, :, 17] = prevm2 * mpop
    return state
end
# Initial state, indexed (risk group, age group, compartment).
const u0 = init_mevotoar(sriskgr,agegr)
# Initial parameter guess, in the order ODEHPV unpacks it:
# DIf, DIPf1, DIPf2, c1, c2, PCC1, PCC2, PRG1, PRG2, DIm, DIPm1, DIPm2,
# Beta, Beta_m, SCRN, DTR, DC, cm1, cm2. Same values as `initp` used for training.
p = [25.0,1.425,0.986,0.81,0.93,0.10774,0.0308,0.05146,0.00508,25.0,1.016,0.525,0.4,0.4,0.04,8.0,10.0,0.670,0.538]
# Matches the observation times `ts` (0..5).
tspan = (0.0,5.0)
prob = ODEProblem(ODEHPV,u0,tspan,p)
# Reference solve with the initial guess (used only by `Js` below).
# NOTE(review): Vern7 is an explicit method; if the system is stiff, a stiff solver may be needed.
sol = solve(prob,Vern7())
# NOTE(review): data_01s..data_04s, data_01..data_04 and their (1, 6) reshapes (datas1..4,
# data_01..04) are not referenced anywhere else in this file; data01 below is built from
# separate literal copies.
data_01s = Float64[0.0746,0.0746,0.0746,0.0746,0.0746,0.0746]
data_02s = Float64[0.0048,0.0048,0.0048,0.0048,0.0048,0.0048]
data_03s = Float64[0.01940,0.01940,0.0194,0.0194,0.0194,0.0194]
data_04s = Float64[0.0048,0.0048,0.0048,0.0048,0.0048,0.0048]
data_01 = Float64[0.0,1500.0,3750.0,5600.0,7200.0,8000.0]
data_02 = Float64[0.0,615.0,1275.0,1990.0,2860.0,3650.0]
data_03 = Float64[0.0,120.0,240.0,360.0,480.0,600.0]
data_04 = Float64[0.0,15.0,30.0,45.0,60.0,75.0]
datas1 = reshape(data_01s, (1, 6))
datas2 = reshape(data_02s, (1, 6))
datas3 = reshape(data_03s, (1, 6))
datas4 = reshape(data_04s, (1, 6))
data_01 = reshape(data_01, (1, 6))
data_02 = reshape(data_02, (1, 6))
data_03 = reshape(data_03, (1, 6))
data_04 = reshape(data_04, (1, 6))
# Per-cell (risk, age, time) series; the same 6-point series is copied into every cell.
datasf1 = Array{Float64}(undef, 3,4,6)
datasf2 = Array{Float64}(undef, 3,4,6)
datasm1 = Array{Float64}(undef, 3,4,6)
datasm2 = Array{Float64}(undef, 3,4,6)
dataincf01 = Array{Float64}(undef, 3,4,6)
dataincf02 = Array{Float64}(undef, 3,4,6)
dataincm01 = Array{Float64}(undef, 3,4,6)
dataincm02 = Array{Float64}(undef, 3,4,6)
for i in 1:3, j in 1:4
datasf1[i,j,:] = Float64[0.0746,0.0746,0.0746,0.0746,0.0746,0.0746]
datasf2[i,j,:] = Float64[0.0048,0.0048,0.0048,0.0048,0.0048,0.0048]
datasm1[i,j,:] = Float64[0.01940,0.01940,0.01940,0.01940,0.01940,0.01940]
# NOTE(review): datasm2 uses 0.0746 (the prevf1 value), while prevm2 = 0.0194 and
# data_04s = 0.0048 - this looks like a copy-paste slip; confirm the intended series.
datasm2[i,j,:] = Float64[0.0746,0.0746,0.0746,0.0746,0.0746,0.0746]
dataincf01[i,j,:] = Float64[0.0,1500.0,3750.0,5600.0,7200.0,8000.0]
dataincf02[i,j,:] = Float64[0.0,615.0,1275.0,1990.0,2860.0,3650.0]
dataincm01[i,j,:] = Float64[0.0,120.0,240.0,360.0,480.0,600.0]
dataincm02[i,j,:] = Float64[0.0,15.0,30.0,45.0,60.0,75.0]
end
# Targets indexed (risk, age, observable, time); observables 1-4 are the prevalence-like
# series above and 5-8 the dataincf*/dataincm* series.
# NOTE(review): 5-8 are in the thousands while 1-4 are below 0.1, so an unweighted squared
# loss is dominated by 5-8.
data01 = Array{Float64}(undef, 3,4,8,6)
# Preallocated buffer with the same (3, 4, 8, 6) layout as data01.
prediction = Array{Float64}(undef, 3,4,8,6)
data01[:,:,1,:] = datasf1
data01[:,:,2,:] = datasf2
data01[:,:,3,:] = datasm1
data01[:,:,4,:] = datasm2
data01[:,:,5,:] = dataincf01
data01[:,:,6,:] = dataincf02
data01[:,:,7,:] = dataincm01
data01[:,:,8,:] = dataincm02
# Observation times; passed as `saveat` so each solve returns exactly these 6 points.
ts = Float64[0.0,1.0,2.0,3.0,4.0,5.0]
# NOTE(review): `Js` is not referenced elsewhere in this file. It computes a dense
# ForwardDiff Jacobian of ODEHPV (228 x 228 here) at every saved step of `sol`, which is
# expensive; possibly left over from the stiff-systems tutorial.
Js = map(u->I + 0.1*ForwardDiff.jacobian(UJacobianWrapper(ODEHPV, 0.0, p), u), sol.u)
"""
    predict_adjoint(p)

Solve the model with parameters `p` at times `ts` and return an array laid out like
`data01`: `(risk, age, observable, time)`.

# Observables
- 1:2 — compartments 2:3 divided by the current female population of the cell
  (sum of compartments 1:14).
- 3:4 — compartments 16:17 divided by the current male population of the cell
  (sum of compartments 15:19).
- 5:8 — compartments 9:12, unscaled.

# Changes from the previous version
- The result is built with `cat` instead of being written into a global buffer.
  Zygote, the default AD backend of `sciml_train`, cannot differentiate array mutation,
  and reusing one global buffer also aliased every returned prediction.
- Prevalences are normalised per cell instead of by the national totals `NF`/`NM`.
  The targets in `data01` are per-cell values equal to the initial prevalences, and
  dividing a single cell's count by the national total could not match them even at t = 0.
"""
function predict_adjoint(p)
    _prob = remake(prob, p = p)
    odedata = Array(solve(_prob, Vern7(), saveat = ts,
                          sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
    femalepop = sum(odedata[:, :, 1:14, :]; dims = 3)
    malepop = sum(odedata[:, :, 15:19, :]; dims = 3)
    estimated01 = odedata[:, :, 2:3, :] ./ femalepop
    estimated02 = odedata[:, :, 16:17, :] ./ malepop
    estimated03 = odedata[:, :, 9:12, :]
    return cat(estimated01, estimated02, estimated03; dims = 3)
end
"""
    loss_adjoint(p)

Return `(loss, prediction)`, where `prediction = predict_adjoint(p)` and `loss` is the
square root of the summed squared residuals against `data01` (all cells, observables and
time points).

NOTE(review): the residuals are unweighted, so observables in the thousands (data01[:,:,5:8,:])
dominate those below 0.1 (data01[:,:,1:4,:]).
"""
function loss_adjoint(p)
    pred = predict_adjoint(p)
    sse = sum(abs2, pred .- data01)
    return sqrt(sse), pred
end
# Training callback: prints the current loss and parameters, then plots a fresh solve of
# `prob` with those parameters. Returning `true` would stop the optimization early.
cb = function (p, l, pred)
    println("Loss: ", l)
    println("Parameters: ", p)
    current = remake(prob; p = p)
    display(plot(solve(current, Vern7())))
    return false
end
# Starting point for the optimizers (same values as `p`).
initp = [25.0,1.425,0.986,0.81,0.93,0.10774,0.0308,0.05146,0.00508,25.0,1.016,0.525,0.4,0.4,0.04,8.0,10.0,0.670,0.538]
# Report the loss once at the starting point.
cb(initp,loss_adjoint(initp)...)
# Stage 1: ADAM (step 0.01) for up to 150000 iterations.
# NOTE(review): cb plots a full solve on every iteration, which is likely to dominate run time here.
res = DiffEqFlux.sciml_train(loss_adjoint, initp, ADAM(0.01), cb = cb, maxiters = 150000)
# Stage 2: BFGS refinement starting from the ADAM result; allow_f_increases=true lets BFGS
# continue after an iteration that does not decrease the loss.
res2 = DiffEqFlux.sciml_train(loss_adjoint, res.u, BFGS(), cb = cb, maxiters = 1500000, allow_f_increases=true)