# Parameter Estimation of an HPV Disease Model (DiffEqFlux) (beginner issue)

Hi,

I'm trying to estimate the parameters of an HPV transmission model with DiffEqFlux, but the optimization seems to get stuck in a local optimum. I'm not sure mini-batching would help, because my observation data has only 6 time points (length 6 in T). I started from the documentation tutorial on parameter estimation of stiff systems, so increases in the objective are also allowed (`allow_f_increases=true`). I'm really stuck on this issue and don't even know where to start; any help would make my day.

Thanks.
Mevlüt.

``````
using DifferentialEquations, DiffEqFlux, LinearAlgebra
using ForwardDiff
using DiffEqBase: UJacobianWrapper
using Plots
using Random
Random.seed!(15000)
# Number of strata along each axis: g genders, s risk groups, a age groups.
# s_ and a_ size the second risk/age axis of the 4-D and 5-D arrays built below
# (presumably the partner's group — TODO confirm).
const g = 2
const s = 3
const s_ = 3
const a = 4
const a_ = 4
# Evenly spaced grids on [0, 1]; the init_* functions below only use their lengths.
const gendergr = range(0,stop=1,length=g)
const sriskgr = range(0,stop=1,length=s)
const agegr = range(0,stop=1,length=a)
# Number of time points and the matching grid (not referenced again in the code shown here).
const TF = 6
const timefinale = range(0,stop=1,length=TF)
# Initial infected fractions: init_mevotoar applies prevf1/prevf2 to compartments 2/3
# (female counts) and prevm1/prevm2 to compartments 16/17 (male counts).
# NOTE(review): prevm2 repeats prevm1 (0.0194) while the series data_04s further down
# uses 0.0048 — confirm which value is intended.
const prevf1 = 0.0746
const prevf2 = 0.0048
const prevm1 = 0.0194
const prevm2 = 0.0194
# Three-element weight vectors (each sums to 1), one per entry of Nfemale (Rsf*) and
# Nmale (Rsm*); init_ngsa multiplies each by the matching count to split it over the
# three risk groups.
const Rsf1 = [0.475,0.3675,0.1575]
const Rsf2 = [0.809,0.1337,0.0573]
const Rsf3 = [0.929,0.0497,0.0213]
const Rsf4 = [0.955,0.0315,0.0135]
const Rsf5 = [0.979,0.0147,0.0063]
const Rsf6 = [0.971,0.0203,0.0087]
const Rsf7 = [0.976,0.0168,0.0072]
const Rsf8 = [0.9808,0.01344,0.00576]
const Rsm1 = [0.0986,0.63098,0.27042]
const Rsm2 = [0.5254,0.33222,0.14238]
const Rsm3 = [0.821,0.1253,0.0537]
const Rsm4 = [0.9201,0.05593,0.02397]
const Rsm5 = [0.9628,0.026,0.0112]
const Rsm6 = [0.971,0.0203,0.0087]
const Rsm7 = [0.9807,0.0135,0.0058]
const Rsm8 = [0.9907,0.00651,0.00279]
# Target for the Ngsa-weighted mean of Cgsa (see cmin below).
const coverall = 0.456
# Eight population counts per gender; init_ngsa merges consecutive pairs into the four
# age groups. NF and NM are the totals.
const Nfemale = [3211000.0, 3198000.0, 3100000.0, 3131000.0, 3269000.0, 2859000.0, 2499000.0, 3617000.0]
const Nmale = [3371000.0, 3303000.0, 3179000.0, 3201000.0, 3316000.0, 2888000.0, 2536000.0, 3425000.0]
const NF = sum(Nfemale)
const NM = sum(Nmale)

"""
    init_probmat(dz, js, ka)

Build the 5-D 0/1 mixing array of size
`(length(dz), length(js), length(js), length(ka), length(ka))`.

For indices 1 and 2 of the first dimension, the trailing `[:, :]` block is the identity
matrix for the risk-index pairs (1,1), (2,2), (3,3), (2,3) and (3,2); every other entry
is zero. Only the lengths of the three arguments are used.
"""
function init_probmat(dz, js, ka)
    ngender = length(dz)
    nrisk = length(js)
    nage = length(ka)

    mixing = zeros(ngender, nrisk, nrisk, nage, nage)
    same_age = I(nage)
    # Risk-index pairs that receive an identity block: each index with itself,
    # plus indices 2 and 3 with each other.
    risk_pairs = ((1, 1), (2, 2), (3, 3), (2, 3), (3, 2))
    for gender in 1:2, (r, q) in risk_pairs
        mixing[gender, r, q, :, :] = same_age
    end
    # NaN scrub: a no-op for the 0/1 entries written above, retained as a safeguard.
    map!(x -> isnan(x) ? zero(x) : x, mixing, mixing)

    return mixing
end

# Mixing array for this model's gender / risk / age grids; consumed by init_balmat below.
const probmat = init_probmat(gendergr,sriskgr,agegr)

"""
    init_ngsa(dz, js, ka)

Return a `(length(dz), length(js), length(ka))` array of group sizes.

For `k = 1:4`, slice `[1, :, k]` is built from `Nfemale` and slice `[2, :, k]` from
`Nmale`: the two consecutive entries `2k-1` and `2k` are each spread over the risk groups
with their own `Rsf*` / `Rsm*` weight vector and then added. Only the lengths of the
arguments are used.
"""
function init_ngsa(dz, js, ka)
    sizes = zeros(length(dz), length(js), length(ka))
    female_weights = (Rsf1, Rsf2, Rsf3, Rsf4, Rsf5, Rsf6, Rsf7, Rsf8)
    male_weights = (Rsm1, Rsm2, Rsm3, Rsm4, Rsm5, Rsm6, Rsm7, Rsm8)

    for k in 1:4
        lo, hi = 2k - 1, 2k
        sizes[1, :, k] = (female_weights[lo] * Nfemale[lo]) + (female_weights[hi] * Nfemale[hi])
        sizes[2, :, k] = (male_weights[lo] * Nmale[lo]) + (male_weights[hi] * Nmale[hi])
    end

    return sizes
end

"""
    init_rgsa(dz, js, ka)

Return a `(length(dz), length(js), length(ka))` array whose slices `[1, :, :]` and
`[2, :, :]` both hold the outer product of the fixed weights `[1, 3, 5]` (second axis)
and `[4, 3, 2, 1]` (third axis). Only the lengths of the arguments are used.
"""
function init_rgsa(dz, js, ka)
    weights = zeros(length(dz), length(js), length(ka))
    risk_weight = [1.0, 3.0, 5.0]
    age_weight = [4.0, 3.0, 2.0, 1.0]
    # The same 3x4 table is used for both indices of the first dimension.
    table = risk_weight * transpose(age_weight)
    for gender in 1:2
        weights[gender, :, :] = table
    end
    return weights
end

# Group sizes (Ngsa) and fixed weights (Rgsa), both indexed [gender, risk, age].
const Ngsa = init_ngsa(gendergr,sriskgr,agegr)
const Rgsa = init_rgsa(gendergr,sriskgr,agegr)
# Elementwise product; only its total is used, to normalise cmin.
const Ngsa_Rgsa = Ngsa.*Rgsa

# cmin rescales Rgsa so that the Ngsa-weighted mean of Cgsa equals coverall,
# i.e. sum(Ngsa .* Cgsa) / sum(Ngsa) == coverall.
const cmin = coverall*(sum(Ngsa)/sum(Ngsa_Rgsa))
const Cgsa = cmin .* Rgsa

"""
    init_balmat(Cgsa, probmat, Ngsa)

Return an `(s, s_, a, a_)` array of ratios. Entry `[i, k, j, l]` is

    Cgsa[1,i,j] * probmat[1,i,k,j,l] * Ngsa[1,i,j]
    ----------------------------------------------
    Cgsa[2,i,j] * probmat[2,i,k,j,l] * Ngsa[2,i,j]

with any `NaN` result (for example `0/0` where both `probmat` entries are zero) replaced
by `0`. The array extents come from the globals `s`, `s_`, `a` and `a_`.
"""
function init_balmat(Cgsa, probmat, Ngsa)
    ratio = zeros(s, s_, a, a_)

    @inbounds for r in 1:s, age in 1:a, q in 1:s_, page in 1:a_
        numerator = Cgsa[1, r, age] * probmat[1, r, q, age, page] * Ngsa[1, r, age]
        denominator = Cgsa[2, r, age] * probmat[2, r, q, age, page] * Ngsa[2, r, age]
        ratio[r, q, age, page] = numerator / denominator
    end

    # Only NaN is scrubbed; an Inf from a nonzero numerator over zero would be kept.
    map!(x -> isnan(x) ? zero(x) : x, ratio, ratio)
    return ratio
end

# Ratio array of size (s, s_, a, a_); consumed by init_Cmat below.
const balmat = init_balmat(Cgsa,probmat,Ngsa)

"""
    init_Cmat(Cgsa, balmat)

Return a `(g, s, s_, a, a_)` array. For every `[i, k, j, l]`:

- slice 1 holds `Cgsa[1,i,j] * balmat[i,k,j,l]^0`, i.e. `Cgsa[1,i,j]` itself;
- slice 2 holds `Cgsa[2,i,j] * balmat[i,k,j,l]^1`.

The array extents come from the globals `g`, `s`, `s_`, `a` and `a_`.
"""
function init_Cmat(Cgsa, balmat)
    contacts = zeros(g, s, s_, a, a_)

    @inbounds for r in 1:s, age in 1:a, q in 1:s_, page in 1:a_
        bal = balmat[r, q, age, page]
        # Exponents 0 and 1: slice 1 ignores the ratio, slice 2 applies it in full.
        contacts[1, r, q, age, page] = Cgsa[1, r, age] * bal^0
        contacts[2, r, q, age, page] = Cgsa[2, r, age] * bal^(1)
    end

    return contacts
end

const Cmat = init_Cmat(Cgsa,balmat)

# Per-(risk, age) contact multipliers used by ODEHPV.
# Each 4-D slice of Cmat is summed over its 2nd and 4th axes, and the two singleton axes
# left behind by mapslices are dropped to give a 3x4 matrix.
actcmult_ = Cmat[1,:,:,:,:]
actcmult_2 = Cmat[2,:,:,:,:]
const cstar = 1.0
actcnew = mapslices(sum, actcmult_, dims = [2,4])
actcnew2 = dropdims(actcnew; dims=2)
actcnew3 = dropdims(actcnew2; dims=3)
# const: read on every ODE right-hand-side evaluation, so the binding should be typed.
const actcmult = cstar * actcnew3

# Same reduction for slice 2. The dropdims chain must start from the reduction of
# actcmult_2 computed on the next line; starting from `actcnew` would make actcmult2
# an exact copy of actcmult.
actcnew2 = mapslices(sum, actcmult_2, dims = [2,4])
actcnew22 = dropdims(actcnew2; dims=2)
actcnew23 = dropdims(actcnew22; dims=3)
const actcmult2 = cstar * actcnew23

# In-place right-hand side of the HPV transmission ODE system, in the (du, u, p, t) form
# passed to ODEProblem.
#   du, u : arrays indexed [risk group i, age group j, compartment 1:19]
#   p     : 19 parameters, unpacked positionally on the first line of the body
#   t     : time (not used in the body)
# Compartments 1-3 are seeded from Ngsa[1,:,:] and 15-17 from Ngsa[2,:,:] in
# init_mevotoar. Equations 1-14 read compartments 16-17, and equations 15-19 read
# compartments 2-3, only inside the Beta / Beta_m infection terms, where they appear as
# fractions u[i,j,k]/Ngsa[...] (the first loop additionally shares the sum(u[:,4,:]) term).
# Every equation removes 0.1*u[i,j,k].
# Reads the globals s, a, actcmult, actcmult2, Ngsa, Rsf1 and Rsm1.
# NOTE(review): actcmult and actcmult2 are not declared const in this file, so their
# accesses here are dynamically typed — declaring them const would likely speed this up.
function ODEHPV(du,u,p,t)
DIf, DIPf1, DIPf2, c1, c2, PCC1, PCC2, PRG1, PRG2,DIm, DIPm1, DIPm2, Beta, Beta_m, SCRN, DTR, DC, cm1, cm2  = p

# Age group j = 1: no 0.1*u[i,j-1,k] inflow. Instead compartments 1 and 15 gain
# 0.05*Rsf1[i] (resp. 0.05*Rsm1[i]) times the sum of u over all risk groups and all 19
# compartments at age index 4.
# NOTE(review): sum(u[:,4,:]) allocates a copy and is recomputed for every i.
@inbounds for j = 1, i in 1:s
du[i,j,1] = (u[i,j,13]+u[i,j,14])/DIf - u[i,j,1]*Beta*actcmult[i,j]*((u[i,j,16]/Ngsa[2,i,j])+(u[i,j,17]/Ngsa[2,i,j])) + u[i,j,8]/DC + u[i,j,9]/DTR + u[i,j,10]/DTR + u[i,j,11]/DTR + u[i,j,12]/DTR - 0.1*u[i,j,1] + 0.05*Rsf1[i]*sum(u[:,4,:])
du[i,j,2] = u[i,j,1]*Beta * actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j])  + u[i,j,4]*(PCC1) + u[i,j,14]*Beta*actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j]) - u[i,j,2]/DIPf1 - u[i,j,2]*SCRN - 0.1*u[i,j,2]
du[i,j,3] = u[i,j,1]*Beta * actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j])  + u[i,j,5]*(PCC1) + u[i,j,13]*Beta*actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j]) - u[i,j,3]/DIPf2 - u[i,j,3]*SCRN- 0.1*u[i,j,3]
du[i,j,4] = u[i,j,2]*(1-c1)/DIPf1 + u[i,j,6]*PCC2 - u[i,j,4]*PRG1 - u[i,j,4]*(PCC1) - u[i,j,4]*SCRN - 0.1*u[i,j,4]
du[i,j,5] = u[i,j,3]*(1-c2)/DIPf2 + u[i,j,7]*PCC2 - u[i,j,5]*PRG1 - u[i,j,5]*(PCC1)  - u[i,j,5]*SCRN - 0.1*u[i,j,5]
du[i,j,6] = u[i,j,4]*PRG1 - u[i,j,6]*PCC2 - u[i,j,6]*SCRN - u[i,j,6]*PRG2 - 0.1*u[i,j,6]
du[i,j,7] = u[i,j,5]*PRG1 - u[i,j,7]*PCC2 - u[i,j,7]*SCRN - u[i,j,7]*PRG2  - 0.1*u[i,j,7]
du[i,j,8] = (u[i,j,6]+u[i,j,7])*PRG2 - u[i,j,8]*SCRN -u[i,j,8]/DC  - 0.1*u[i,j,8]
du[i,j,9] = (u[i,j,2]+u[i,j,3])*SCRN - u[i,j,9]/DTR - 0.1*u[i,j,9]
du[i,j,10] = (u[i,j,4]+u[i,j,5])*SCRN - u[i,j,10]/DTR - 0.1*u[i,j,10]
du[i,j,11] = (u[i,j,6]+u[i,j,7])*SCRN - u[i,j,11]/DTR - 0.1*u[i,j,11]
du[i,j,12] = u[i,j,8]*SCRN  - u[i,j,12]/DTR - 0.1*u[i,j,12]
du[i,j,13] = u[i,j,2]*(c1)/DIPf1 - u[i,j,13]*Beta*actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j]) - u[i,j,13]/DIf  - 0.1*u[i,j,13]
du[i,j,14] = u[i,j,3]*(c2)/DIPf2 - u[i,j,14]*Beta*actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j]) - u[i,j,14]/DIf - 0.1*u[i,j,14]
du[i,j,15] = (u[i,j,18]+u[i,j,19])/DIm - u[i,j,15]*Beta_m*actcmult2[i,j]*((u[i,j,3]/Ngsa[1,i,j])+(u[i,j,2]/Ngsa[1,i,j])) - 0.1*u[i,j,15] + 0.05*Rsm1[i]*(sum(u[:,4,:]))
# NOTE(review): the gains of compartments 16/17 from 19/18 use Beta*actcmult, while the
# matching losses of 19/18 two lines further down use Beta_m*actcmult2, so this transfer
# does not balance unless both pairs are equal — confirm which coefficients are intended.
du[i,j,16] = u[i,j,15]* Beta_m*actcmult2[i,j]*(u[i,j,2]/Ngsa[1,i,j]) +  u[i,j,19]*Beta*actcmult[i,j]*(u[i,j,2]/Ngsa[1,i,j]) - u[i,j,16]*cm1/DIPm1 - 0.1*u[i,j,16]
du[i,j,17] = u[i,j,15]* Beta_m*actcmult2[i,j]*(u[i,j,3]/Ngsa[1,i,j]) +  u[i,j,18]*Beta*actcmult[i,j]*(u[i,j,3]/Ngsa[1,i,j]) - u[i,j,17]*cm2/DIPm2 - 0.1*u[i,j,17]
du[i,j,18] = u[i,j,16]*cm1/DIPm1 - u[i,j,18]/DIm - u[i,j,18]*Beta_m*actcmult2[i,j]*(u[i,j,3]/Ngsa[1,i,j]) - 0.1*u[i,j,18]
du[i,j,19] = u[i,j,17]*cm2/DIPm2 - u[i,j,19]/DIm - u[i,j,19]*Beta_m*actcmult2[i,j]*(u[i,j,2]/Ngsa[1,i,j]) - 0.1*u[i,j,19]

end

# Age groups j = 2:a: same equations, with every compartment k additionally gaining
# 0.1*u[i,j-1,k] from the previous age group (the counterpart of the 0.1*u[i,j,k] loss).
@inbounds for j in 2:a, i in 1:s
du[i,j,1] = (u[i,j,13]+u[i,j,14])/DIf - u[i,j,1]*Beta*actcmult[i,j]*((u[i,j,16]/Ngsa[2,i,j])+(u[i,j,17]/Ngsa[2,i,j])) + u[i,j,8]/DC + u[i,j,9]/DTR + u[i,j,10]/DTR + u[i,j,11]/DTR + u[i,j,12]/DTR - 0.1*u[i,j,1] + 0.1*u[i,j-1,1]
du[i,j,2] = u[i,j,1]*Beta * actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j]) + u[i,j,4]*(PCC1)  + u[i,j,14]*Beta*actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j]) - u[i,j,2]/DIPf1 - u[i,j,2]*SCRN- 0.1*u[i,j,2] + 0.1*u[i,j-1,2]
du[i,j,3] = u[i,j,1]*Beta * actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j]) + u[i,j,5]*(PCC1) + u[i,j,13]*Beta*actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j]) - u[i,j,3]/DIPf2 - u[i,j,3]*SCRN- 0.1*u[i,j,3] + 0.1*u[i,j-1,3]
du[i,j,4] = u[i,j,2]*(1-c1)/DIPf1 + u[i,j,6]*PCC2 - u[i,j,4]*PRG1 - u[i,j,4]*(PCC1) - u[i,j,4]*SCRN - 0.1*u[i,j,4] + 0.1*u[i,j-1,4]
du[i,j,5] = u[i,j,3]*(1-c2)/DIPf2 + u[i,j,7]*PCC2 - u[i,j,5]*PRG1 - u[i,j,5]*(PCC1)- u[i,j,5]*SCRN - 0.1*u[i,j,5] + 0.1*u[i,j-1,5]
du[i,j,6] = u[i,j,4]*PRG1 - u[i,j,6]*PCC2 - u[i,j,6]*SCRN - u[i,j,6]*PRG2 - 0.1*u[i,j,6] + 0.1*u[i,j-1,6]
du[i,j,7] = u[i,j,5]*PRG1 - u[i,j,7]*PCC2 - u[i,j,7]*SCRN - u[i,j,7]*PRG2 - 0.1*u[i,j,7] + 0.1*u[i,j-1,7]
du[i,j,8] = (u[i,j,6]+u[i,j,7])*PRG2 - u[i,j,8]*SCRN -u[i,j,8]/DC  - 0.1*u[i,j,8] + 0.1*u[i,j-1,8]
du[i,j,9] = (u[i,j,2]+u[i,j,3])*SCRN - u[i,j,9]/DTR - 0.1*u[i,j,9] + 0.1*u[i,j-1,9]
du[i,j,10] = (u[i,j,4]+u[i,j,5])*SCRN - u[i,j,10]/DTR - 0.1*u[i,j,10] + 0.1*u[i,j-1,10]
du[i,j,11] = (u[i,j,6]+u[i,j,7])*SCRN - u[i,j,11]/DTR - 0.1*u[i,j,11] + 0.1*u[i,j-1,11]
du[i,j,12] = u[i,j,8]*SCRN  - u[i,j,12]/DTR - 0.1*u[i,j,12] + 0.1*u[i,j-1,12]
du[i,j,13] = u[i,j,2]*(c1)/DIPf1 - u[i,j,13]*Beta*actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j]) - u[i,j,13]/DIf - 0.1*u[i,j,13] + 0.1*u[i,j-1,13]
du[i,j,14] = u[i,j,3]*(c2)/DIPf2 - u[i,j,14]*Beta*actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j]) - u[i,j,14]/DIf - 0.1*u[i,j,14] + 0.1*u[i,j-1,14]
du[i,j,15] = (u[i,j,18]+u[i,j,19])/DIm - u[i,j,15]*Beta_m*actcmult2[i,j]*((u[i,j,3]/Ngsa[1,i,j])+(u[i,j,2]/Ngsa[1,i,j])) - 0.1*u[i,j,15]+ 0.1*u[i,j-1,15]
# NOTE(review): same Beta*actcmult (gain) versus Beta_m*actcmult2 (loss) mismatch between
# equations 16/17 and 18/19 as in the j = 1 loop — confirm.
du[i,j,16] = u[i,j,15]* Beta_m*actcmult2[i,j]*(u[i,j,2]/Ngsa[1,i,j]) +  u[i,j,19]*Beta*actcmult[i,j]*(u[i,j,2]/Ngsa[1,i,j]) - u[i,j,16]*cm1/DIPm1 - 0.1*u[i,j,16]+ 0.1*u[i,j-1,16]
du[i,j,17] = u[i,j,15]* Beta_m*actcmult2[i,j]*(u[i,j,3]/Ngsa[1,i,j]) +  u[i,j,18]*Beta*actcmult[i,j]*(u[i,j,3]/Ngsa[1,i,j]) - u[i,j,17]*cm2/DIPm2 - 0.1*u[i,j,17]+ 0.1*u[i,j-1,17]
du[i,j,18] = u[i,j,16]*cm1/DIPm1 - u[i,j,18]/DIm - u[i,j,18]*Beta_m*actcmult2[i,j]*(u[i,j,3]/Ngsa[1,i,j]) - 0.1*u[i,j,18] + 0.1*u[i,j-1,18]
du[i,j,19] = u[i,j,17]*cm2/DIPm2 - u[i,j,19]/DIm - u[i,j,19]*Beta_m*actcmult2[i,j]*(u[i,j,2]/Ngsa[1,i,j]) - 0.1*u[i,j,19] + 0.1*u[i,j-1,19]
end

end

"""
    init_mevotoar(js, ka)

Return the initial state array of size `(length(js), length(ka), 19)`.

Only six compartments start non-zero:

- 1, 2, 3 hold the fractions `1 - prevf1 - prevf2`, `prevf1`, `prevf2` of `Ngsa[1, :, :]`;
- 15, 16, 17 hold the fractions `1 - prevm1 - prevm2`, `prevm1`, `prevm2` of `Ngsa[2, :, :]`.

Compartments 4-14, 18 and 19 are left at zero.
"""
function init_mevotoar(js, ka)
    state = zeros(length(js), length(ka), 19)
    first_group = Ngsa[1, :, :]
    second_group = Ngsa[2, :, :]

    state[:, :, 1] = (1-prevf1-prevf2) * first_group
    state[:, :, 2] = prevf1 * first_group
    state[:, :, 3] = prevf2 * first_group

    state[:, :, 15] = (1-prevm1-prevm2) * second_group
    state[:, :, 16] = prevm1 * second_group
    state[:, :, 17] = prevm2 * second_group

    return state
end

# Initial state, sized length(sriskgr) x length(agegr) x 19 compartments.
const u0 = init_mevotoar(sriskgr,agegr)

# Starting parameter vector; ODEHPV unpacks its 19 entries positionally
# (DIf, DIPf1, DIPf2, c1, c2, PCC1, PCC2, PRG1, PRG2, DIm, DIPm1, DIPm2, Beta, Beta_m,
# SCRN, DTR, DC, cm1, cm2).
p = [25.0,1.425,0.986,0.81,0.93,0.10774,0.0308,0.05146,0.00508,25.0,1.016,0.525,0.4,0.4,0.04,8.0,10.0,0.670,0.538]
tspan = (0.0,5.0)

prob = ODEProblem(ODEHPV,u0,tspan,p)
# Reference solve with the starting parameters; sol.u is reused below when building Js.
sol = solve(prob,Vern7())

# Observed series, one value per entry of `ts` (six time points).
# data_01s..data_04s are constant fraction-valued series; data_01..data_04 are count-valued
# series (presumably prevalence and case counts respectively — TODO confirm).
data_01s = Float64[0.0746,0.0746,0.0746,0.0746,0.0746,0.0746]
data_02s = Float64[0.0048,0.0048,0.0048,0.0048,0.0048,0.0048]
data_03s = Float64[0.01940,0.01940,0.0194,0.0194,0.0194,0.0194]
data_04s = Float64[0.0048,0.0048,0.0048,0.0048,0.0048,0.0048]
data_01 = Float64[0.0,1500.0,3750.0,5600.0,7200.0,8000.0]
data_02 = Float64[0.0,615.0,1275.0,1990.0,2860.0,3650.0]
data_03 = Float64[0.0,120.0,240.0,360.0,480.0,600.0]
data_04 = Float64[0.0,15.0,30.0,45.0,60.0,75.0]

# 1x6 row views of the same series (the count series are rebound to their row form).
datas1 = reshape(data_01s, (1, 6))
datas2 = reshape(data_02s, (1, 6))
datas3 = reshape(data_03s, (1, 6))
datas4 = reshape(data_04s, (1, 6))
data_01 = reshape(data_01, (1, 6))
data_02 = reshape(data_02, (1, 6))
data_03 = reshape(data_03, (1, 6))
data_04 = reshape(data_04, (1, 6))

# Replicate one 6-point series over every (risk, age) stratum, giving a 3x4x6 array.
stratify(series) = [series[t] for i in 1:3, j in 1:4, t in eachindex(series)]

# Each target array is built from its own series above rather than from repeated
# literals, so the male second series (datasm2) carries the data_04s value 0.0048
# instead of a copy of the female first series.
datasf1 = stratify(data_01s)
datasf2 = stratify(data_02s)
datasm1 = stratify(data_03s)
datasm2 = stratify(data_04s)

dataincf01 = stratify(data_01)
dataincf02 = stratify(data_02)
dataincm01 = stratify(data_03)
dataincm02 = stratify(data_04)

# Stack the eight targets along dimension 3: 1-4 are the fraction series, 5-8 the counts.
data01 = Array{Float64}(undef, 3, 4, 8, 6)
for (k, target) in enumerate((datasf1, datasf2, datasm1, datasm2,
                              dataincf01, dataincf02, dataincm01, dataincm02))
    data01[:, :, k, :] = target
end

# Output buffer with the same layout as data01; it is filled by the prediction code below.
prediction = Array{Float64}(undef, 3,4,8,6)

# Observation times matching the six entries of every series.
ts = Float64[0.0,1.0,2.0,3.0,4.0,5.0]

# One matrix I + 0.1*J per saved state in sol.u, where J is the ForwardDiff Jacobian of
# the wrapped ODEHPV right-hand side at t = 0.0 with the starting p (presumably taken
# with respect to u — confirm against UJacobianWrapper). Js is not referenced again in
# the code shown here.
Js = map(u->I + 0.1*ForwardDiff.jacobian(UJacobianWrapper(ODEHPV, 0.0, p), u), sol.u)

# NOTE(review): the two `end` keywords below have no matching `function` line, `odedata`
# is never assigned, and the training call further down refers to `loss_adjoint`. It looks
# as if the function headers and the solve call that produces `odedata` were lost when
# this code was pasted — TODO restore them from the original script.
p = p
_prob = remake(prob,p=p)
# Compartments 2:3 are divided by the total NF and 16:17 by the total NM; compartments
# 9:12 are used unscaled.
estimated01 = odedata[:,:,2:3,:] ./ NF
estimated02 = odedata[:,:,16:17,:] ./ NM
estimated03 = odedata[:,:,9:12,:]
# NOTE(review): these assignments write into the global `prediction` buffer; in-place
# array mutation is generally not supported by reverse-mode AD (Zygote) — confirm that
# the chosen AD backend can differentiate through this.
prediction[:,:,1:2,:] = estimated01
prediction[:,:,3:4,:] = estimated02
prediction[:,:,5:8,:] = estimated03
prediction
end

# Elementwise squared error between prediction and data01; sum(diff) is a scalar, so
# `loss` is the Euclidean norm of the whole residual.
# NOTE(review): data01 mixes entries of order 1e-2 with entries of order 1e3 without
# scaling, so the count-valued rows will dominate this loss — confirm this is intended.
diff = map((u,data) -> abs2.(u .- data) , prediction, data01)
loss = sum(abs, sum(diff)) |> sqrt
loss, prediction
end

# Training callback: print the current loss and parameters, plot the trajectory obtained
# with those parameters, and tell the optimiser to continue.
#   p    : current parameter vector
#   l    : current loss value
#   pred : current prediction (not used here)
cb = function (p, l, pred) # callback function to observe training
    # Plain `$` interpolation prints the values; an escaped `\$` would print the
    # literal text "$l" / "$(p)" on every iteration instead.
    println("Loss: $l")
    println("Parameters: $(p)")
    # using `remake` to re-create our `prob` with current parameters `p`
    display(plot(solve(remake(prob, p=p), Vern7())))
    return false # Tell it to not halt the optimization. If return true, then optimization stops
end

# Starting point for the optimiser (the same 19 values as the initial `p`).
initp = [25.0,1.425,0.986,0.81,0.93,0.10774,0.0308,0.05146,0.00508,25.0,1.016,0.525,0.4,0.4,0.04,8.0,10.0,0.670,0.538]

# Stage 1: first-order warm-up from `initp`. This defines `res`, which stage 2 reads;
# without it `res.u` raises an UndefVarError and `initp` is never used.
# NOTE(review): optimiser, step size and iteration count follow the stiff-system
# parameter-estimation tutorial this script is based on — tune as needed.
res = DiffEqFlux.sciml_train(loss_adjoint, initp, ADAM(0.01), cb = cb, maxiters = 300)

# Stage 2: BFGS refinement started from the stage-1 result.
res2 = DiffEqFlux.sciml_train(loss_adjoint, res.u, BFGS(), cb = cb, maxiters = 1500000, allow_f_increases=true)

``````
1 Like

Did you try multiple shooting?

https://diffeqflux.sciml.ai/dev/examples/multiple_shooting/

Or successively growing the time interval that is fitted?

https://diffeqflux.sciml.ai/dev/examples/local_minima/

Also, decreasing the ODE solver tolerances (`abstol` / `reltol`) helps to improve gradient accuracy.