Parameter Estimation of HPV Disease Model (DiffEqFlux) (beginner issue)


I'm trying to estimate the parameters of an HPV transmission model with DiffEqFlux, but the optimization seems to get stuck in a local optimum. I'm not sure about using mini-batching, because my observation data has only 6 time points. I started by studying the tutorial on parameter estimation of stiff systems in the documentation, so f increases are also allowed (`allow_f_increases = true`). I'm really stuck on this issue and don't even know where to start; any help would make my day.


using DifferentialEquations, DiffEqFlux, LinearAlgebra
using ForwardDiff
using DiffEqBase: UJacobianWrapper
using Plots
using Random
# --- Model grid sizes -------------------------------------------------------
# g: number of sexes, s/s_: risk groups (own/second index), a/a_: age groups.
const g = 2
const s = 3
const s_ = 3
const a = 4
const a_ = 4

# Placeholder grids; only their lengths are used by the init_* functions.
const gendergr = range(0, stop = 1, length = g)
const sriskgr = range(0, stop = 1, length = s)
const agegr = range(0, stop = 1, length = a)

# Number of observation times and a matching unit grid.
const TF = 6
const timefinale = range(0, stop = 1, length = TF)

# Initial prevalence fractions (used by init_mevotoar).
const prevf1 = 0.0746
const prevf2 = 0.0048
const prevm1 = 0.0194
const prevm2 = 0.0194

# Risk-group proportions for each of the 8 female population strata (rows of Nfemale).
const Rsf1 = [0.475, 0.3675, 0.1575]
const Rsf2 = [0.809, 0.1337, 0.0573]
const Rsf3 = [0.929, 0.0497, 0.0213]
const Rsf4 = [0.955, 0.0315, 0.0135]
const Rsf5 = [0.979, 0.0147, 0.0063]
const Rsf6 = [0.971, 0.0203, 0.0087]
const Rsf7 = [0.976, 0.0168, 0.0072]
const Rsf8 = [0.9808, 0.01344, 0.00576]

# Risk-group proportions for each of the 8 male population strata (rows of Nmale).
const Rsm1 = [0.0986, 0.63098, 0.27042]
const Rsm2 = [0.5254, 0.33222, 0.14238]
const Rsm3 = [0.821, 0.1253, 0.0537]
const Rsm4 = [0.9201, 0.05593, 0.02397]
const Rsm5 = [0.9628, 0.026, 0.0112]
const Rsm6 = [0.971, 0.0203, 0.0087]
const Rsm7 = [0.9807, 0.0135, 0.0058]
const Rsm8 = [0.9907, 0.00651, 0.00279]

# Target population-weighted mean contact rate (see cmin below).
const coverall = 0.456

# Population counts per stratum, and their totals.
const Nfemale = [3211000.0, 3198000.0, 3100000.0, 3131000.0,
                 3269000.0, 2859000.0, 2499000.0, 3617000.0]
const Nmale = [3371000.0, 3303000.0, 3179000.0, 3201000.0,
               3316000.0, 2888000.0, 2536000.0, 3425000.0]
const NF = sum(Nfemale)
const NM = sum(Nmale)

"""
    init_probmat(dz, js, ka)

Build the 0/1 mixing-pattern array `probmat[z, i, k, j, l]` of size
`length(dz) × length(js) × length(js) × length(ka) × length(ka)`.

`z` indexes the `dz` groups, `i`/`k` two risk-group indices (from `js`) and `j`/`l` two
age-group indices (from `ka`). For both `z = 1` and `z = 2` the allowed `(i, k)` risk pairs
are the diagonal plus `(2, 3)` and `(3, 2)`; for each allowed pair the age block is the
identity, so only equal age indices are linked. Every other entry is 0.

Assumes `length(dz) >= 2` and `length(js) >= 3` (the indices written are hard-coded).
"""
function init_probmat(dz, js, ka)
    Z = length(dz)
    J = length(js)
    K = length(ka)

    probmat = zeros(Z, J, J, K, K)
    allowed = ((1, 1), (2, 2), (3, 3), (2, 3), (3, 2))
    for z in 1:2, (i, k) in allowed
        probmat[z, i, k, :, :] = I(K)
    end
    # The array is built only from zeros and identity blocks, so it can contain no NaN;
    # the former NaN-to-zero sweep was a no-op and is dropped.
    return probmat
end


# Mixing pattern for the configured sex / risk / age grids.
const probmat = init_probmat(gendergr, sriskgr, agegr)

"""
    init_ngsa(dz, js, ka)

Return the population array `Ngsa[z, i, k]` of size `length(dz) × length(js) × length(ka)`.

Row `z = 1` is built from `Nfemale` and the `Rsf*` proportions, row `z = 2` from `Nmale`
and the `Rsm*` proportions. Model age group `k` (1–4) pools strata `2k-1` and `2k`, and
each stratum is split across the risk groups by its proportion vector.

Assumes `length(js) == 3` (length of the proportion vectors) and `length(ka) >= 4`.
Previously the function had no `return`, so it returned the value of its last assignment
(a 3-element vector) instead of the array.
"""
function init_ngsa(dz, js, ka)
    Z = length(dz)
    J = length(js)
    K = length(ka)
    Ngsa = zeros(Z, J, K)

    Rsf = (Rsf1, Rsf2, Rsf3, Rsf4, Rsf5, Rsf6, Rsf7, Rsf8)
    Rsm = (Rsm1, Rsm2, Rsm3, Rsm4, Rsm5, Rsm6, Rsm7, Rsm8)
    for k in 1:4
        lo, hi = 2k - 1, 2k
        Ngsa[1, :, k] = (Rsf[lo] * Nfemale[lo]) + (Rsf[hi] * Nfemale[hi])
        Ngsa[2, :, k] = (Rsm[lo] * Nmale[lo]) + (Rsm[hi] * Nmale[hi])
    end
    return Ngsa
end


"""
    init_rgsa(dz, js, ka)

Return the relative-activity array `Rgsa[z, i, k]` of size
`length(dz) × length(js) × length(ka)`.

For `z = 1` and `z = 2` the slice is the outer product `Rs * Raᵀ`, with risk weights
`Rs = [1, 3, 5]` and age weights `Ra = [4, 3, 2, 1]`. Any further `z` slices stay zero.
Assumes `length(js) == 3`, `length(ka) == 4` and `length(dz) >= 2`.

Previously the function returned the value of its last assignment (the 3×4 male slice)
instead of `Rgsa`.
"""
function init_rgsa(dz, js, ka)
    G = length(dz)
    J = length(js)
    K = length(ka)

    Rs = [1.0, 3.0, 5.0]
    Ra = [4.0, 3.0, 2.0, 1.0]
    Rsa = Rs * transpose(Ra)   # Rsa[i, k] = Rs[i] * Ra[k]

    Rgsa = zeros(G, J, K)
    Rgsa[1, :, :] = Rsa
    Rgsa[2, :, :] = Rsa
    return Rgsa
end

# Population and relative activity per (sex, risk group, age group).
const Ngsa = init_ngsa(gendergr, sriskgr, agegr)
const Rgsa = init_rgsa(gendergr, sriskgr, agegr)
const Ngsa_Rgsa = Ngsa .* Rgsa

# Scale factor chosen so that cmin * sum(Ngsa .* Rgsa) == coverall * sum(Ngsa), i.e. the
# population-weighted mean of Cgsa equals `coverall`.
const cmin = coverall * (sum(Ngsa) / sum(Ngsa_Rgsa))
const Cgsa = cmin .* Rgsa

"""
    init_balmat(Cgsa, probmat, Ngsa)

Return the balancing ratio `balmat[i, k, j, l]` between the first and second sex rows:

    (Cgsa[1,i,j] * probmat[1,i,k,j,l] * Ngsa[1,i,j]) /
    (Cgsa[2,i,j] * probmat[2,i,k,j,l] * Ngsa[2,i,j])

Entries that evaluate to `NaN` (e.g. `0/0` where `probmat` is zero for both rows) are
set to 0. The output size is `size(probmat)[2:5]`. It used to be taken from the globals
`s, s_, a, a_`; deriving it from the argument makes the function self-contained.
"""
function init_balmat(Cgsa, probmat, Ngsa)
    _, J, L, K, M = size(probmat)
    balmat = zeros(J, L, K, M)

    # First index innermost for column-major access.
    @inbounds for l in 1:M, j in 1:K, k in 1:L, i in 1:J
        num = Cgsa[1, i, j] * probmat[1, i, k, j, l] * Ngsa[1, i, j]
        den = Cgsa[2, i, j] * probmat[2, i, k, j, l] * Ngsa[2, i, j]
        balmat[i, k, j, l] = num / den
    end
    balmat .= ifelse.(isnan.(balmat), 0, balmat)
    return balmat
end

# Female/male balancing ratio for every (risk, risk, age, age) combination.
const balmat = init_balmat(Cgsa, probmat, Ngsa)

"""
    init_Cmat(Cgsa, balmat)

Return the balanced contact array `Cmat[z, i, k, j, l]` of size
`(size(Cgsa, 1), size(balmat)...)`:

- `Cmat[1, i, k, j, l] = Cgsa[1, i, j]` (the original wrote this as `Cgsa * balmat^0`);
- `Cmat[2, i, k, j, l] = Cgsa[2, i, j] * balmat[i, k, j, l]`.

So the whole balancing adjustment is applied to the second row. Slices `z > 2`, if any,
stay zero.

Previously the function ended with the `for` loop and therefore returned `nothing`. Its
dimensions also came from the globals `g, s, s_, a, a_`.
"""
function init_Cmat(Cgsa, balmat)
    J, L, K, M = size(balmat)
    Cmat = zeros(size(Cgsa, 1), J, L, K, M)

    @inbounds for l in 1:M, j in 1:K, k in 1:L, i in 1:J
        Cmat[1, i, k, j, l] = Cgsa[1, i, j]                       # x^0 == 1 for every x
        Cmat[2, i, k, j, l] = Cgsa[2, i, j] * balmat[i, k, j, l]
    end
    return Cmat
end


# Balanced contact rates, indexed [sex, risk, risk, age, age].
const Cmat = init_Cmat(Cgsa, balmat)

# Per-sex slices of the balanced contact array, indexed [risk, risk', age, age'].
actcmult_ = Cmat[1, :, :, :, :]
actcmult_2 = Cmat[2, :, :, :, :]
const cstar = 1.0

# Total contact rate per (risk group, age group): sum over the second risk index (dim 2)
# and the second age index (dim 4), then drop those singleton dims -> size (s, a).
# Declared `const` because ODEHPV reads them on every RHS evaluation.
const actcmult = cstar * dropdims(sum(actcmult_; dims = (2, 4)); dims = (2, 4))
# Male totals. The previous version reduced `actcnew` (the female sums) here, so
# `actcmult2` was silently identical to `actcmult`.
const actcmult2 = cstar * dropdims(sum(actcmult_2; dims = (2, 4)); dims = (2, 4))

"""
    ODEHPV(du, u, p, t)

In-place right-hand side of the HPV transmission model, in the `f(du, u, p, t)` form
used by `ODEProblem`.

`u[i, j, k]` is compartment `k` (1-19) of risk group `i` (first axis) and age group `j`
(second axis), laid out as in `init_mevotoar`. Compartments 1-14 belong to the female
population and 15-19 to the male population. `p` is destructured positionally into the
19 named parameters below. `t` is unused (the system is autonomous).

Besides the within-group flows, every compartment leaves its age group at rate 0.1 and
enters the next age group. The first age group is refilled from the total size of the
oldest age group, split by `Rsf1` (compartment 1) and `Rsm1` (compartment 15). Both
vectors sum to 1, so this inflow (0.1 × oldest) equals the outflow from the oldest group.

Reads the file-level constants `Ngsa`, `actcmult`, `actcmult2`, `Rsf1` and `Rsm1`.

Fix: the male gain terms for compartments 16 and 17 from 19 and 18 used `Beta` and
`actcmult`, while the matching losses from 19 and 18 used `Beta_m` and `actcmult2`. Both
sides now use the male pressures, so every flow out of one compartment is an equal flow
into another.
"""
function ODEHPV(du, u, p, t)
    (DIf, DIPf1, DIPf2, c1, c2, PCC1, PCC2, PRG1, PRG2, DIm, DIPm1, DIPm2,
     Beta, Beta_m, SCRN, DTR, DC, cm1, cm2) = p

    # Total of the oldest age group (all groups, all compartments); computed once
    # instead of once per equation.
    oldest = sum(u[:, end, :])

    @inbounds for j in axes(u, 2), i in axes(u, 1)
        Nf = Ngsa[1, i, j]
        Nm = Ngsa[2, i, j]
        # Per-capita infection pressures from the other sex's compartments 16/17 and 2/3.
        lf16 = Beta * actcmult[i, j] * (u[i, j, 16] / Nm)
        lf17 = Beta * actcmult[i, j] * (u[i, j, 17] / Nm)
        lm2 = Beta_m * actcmult2[i, j] * (u[i, j, 2] / Nf)
        lm3 = Beta_m * actcmult2[i, j] * (u[i, j, 3] / Nf)

        # Female compartments 1-14.
        du[i, j, 1] = (u[i, j, 13] + u[i, j, 14]) / DIf - u[i, j, 1] * (lf16 + lf17) +
                      u[i, j, 8] / DC +
                      (u[i, j, 9] + u[i, j, 10] + u[i, j, 11] + u[i, j, 12]) / DTR
        du[i, j, 2] = (u[i, j, 1] + u[i, j, 14]) * lf16 + u[i, j, 4] * PCC1 -
                      u[i, j, 2] / DIPf1 - u[i, j, 2] * SCRN
        du[i, j, 3] = (u[i, j, 1] + u[i, j, 13]) * lf17 + u[i, j, 5] * PCC1 -
                      u[i, j, 3] / DIPf2 - u[i, j, 3] * SCRN
        du[i, j, 4] = u[i, j, 2] * (1 - c1) / DIPf1 + u[i, j, 6] * PCC2 -
                      u[i, j, 4] * (PRG1 + PCC1 + SCRN)
        du[i, j, 5] = u[i, j, 3] * (1 - c2) / DIPf2 + u[i, j, 7] * PCC2 -
                      u[i, j, 5] * (PRG1 + PCC1 + SCRN)
        du[i, j, 6] = u[i, j, 4] * PRG1 - u[i, j, 6] * (PCC2 + SCRN + PRG2)
        du[i, j, 7] = u[i, j, 5] * PRG1 - u[i, j, 7] * (PCC2 + SCRN + PRG2)
        du[i, j, 8] = (u[i, j, 6] + u[i, j, 7]) * PRG2 -
                      u[i, j, 8] * SCRN - u[i, j, 8] / DC
        du[i, j, 9] = (u[i, j, 2] + u[i, j, 3]) * SCRN - u[i, j, 9] / DTR
        du[i, j, 10] = (u[i, j, 4] + u[i, j, 5]) * SCRN - u[i, j, 10] / DTR
        du[i, j, 11] = (u[i, j, 6] + u[i, j, 7]) * SCRN - u[i, j, 11] / DTR
        du[i, j, 12] = u[i, j, 8] * SCRN - u[i, j, 12] / DTR
        du[i, j, 13] = u[i, j, 2] * c1 / DIPf1 - u[i, j, 13] * lf17 - u[i, j, 13] / DIf
        du[i, j, 14] = u[i, j, 3] * c2 / DIPf2 - u[i, j, 14] * lf16 - u[i, j, 14] / DIf

        # Male compartments 15-19.
        du[i, j, 15] = (u[i, j, 18] + u[i, j, 19]) / DIm - u[i, j, 15] * (lm3 + lm2)
        du[i, j, 16] = (u[i, j, 15] + u[i, j, 19]) * lm2 - u[i, j, 16] * cm1 / DIPm1
        du[i, j, 17] = (u[i, j, 15] + u[i, j, 18]) * lm3 - u[i, j, 17] * cm2 / DIPm2
        du[i, j, 18] = u[i, j, 16] * cm1 / DIPm1 - u[i, j, 18] / DIm - u[i, j, 18] * lm3
        du[i, j, 19] = u[i, j, 17] * cm2 / DIPm2 - u[i, j, 19] / DIm - u[i, j, 19] * lm2

        # Ageing between age groups (the only difference between j == 1 and j > 1).
        for k in axes(u, 3)
            du[i, j, k] -= 0.1 * u[i, j, k]
            if j > 1
                du[i, j, k] += 0.1 * u[i, j - 1, k]
            end
        end
        if j == 1
            du[i, j, 1] += 0.05 * Rsf1[i] * oldest
            du[i, j, 15] += 0.05 * Rsm1[i] * oldest
        end
    end
    return nothing
end


"""
    init_mevotoar(js, ka)

Return the initial state `u[i, j, k]` of size `length(js) × length(ka) × 19` for `ODEHPV`.

Compartments 1-3 split `Ngsa[1, :, :]` into fractions `1 - prevf1 - prevf2`, `prevf1`
and `prevf2`. Compartments 15-17 split `Ngsa[2, :, :]` into `1 - prevm1 - prevm2`,
`prevm1` and `prevm2`. All other compartments start at zero.

Assumes `(length(js), length(ka)) == size(Ngsa)[2:3]`. Previously the function returned
its last assignment (a zero matrix) instead of `u`.
"""
function init_mevotoar(js, ka)
    J = length(js)
    K = length(ka)
    u = zeros(J, K, 19)   # compartments 4-14, 18 and 19 stay zero

    u[:, :, 1] = (1 - prevf1 - prevf2) * Ngsa[1, :, :]
    u[:, :, 2] = prevf1 * Ngsa[1, :, :]
    u[:, :, 3] = prevf2 * Ngsa[1, :, :]

    u[:, :, 15] = (1 - prevm1 - prevm2) * Ngsa[2, :, :]
    u[:, :, 16] = prevm1 * Ngsa[2, :, :]
    u[:, :, 17] = prevm2 * Ngsa[2, :, :]
    return u
end

# Initial state for the configured risk / age grids.
const u0 = init_mevotoar(sriskgr, agegr)

# Nominal parameters; the order must match the destructuring at the top of ODEHPV.
p = [25.0, 1.425, 0.986, 0.81, 0.93, 0.10774, 0.0308, 0.05146, 0.00508,
     25.0, 1.016, 0.525, 0.4, 0.4, 0.04, 8.0, 10.0, 0.670, 0.538]
tspan = (0.0, 5.0)

# Forward solve at the nominal parameters.
prob = ODEProblem(ODEHPV, u0, tspan, p)
sol = solve(prob, Vern7())

# Observed targets, one value per output time in `ts` (t = 0, 1, ..., 5).
# Constant fraction series; data_01s..data_03s equal prevf1, prevf2 and prevm1.
data_01s = Float64[0.0746, 0.0746, 0.0746, 0.0746, 0.0746, 0.0746]
data_02s = Float64[0.0048, 0.0048, 0.0048, 0.0048, 0.0048, 0.0048]
data_03s = Float64[0.01940, 0.01940, 0.0194, 0.0194, 0.0194, 0.0194]
data_04s = Float64[0.0048, 0.0048, 0.0048, 0.0048, 0.0048, 0.0048]
# Increasing count series.
data_01 = Float64[0.0, 1500.0, 3750.0, 5600.0, 7200.0, 8000.0]
data_02 = Float64[0.0, 615.0, 1275.0, 1990.0, 2860.0, 3650.0]
data_03 = Float64[0.0, 120.0, 240.0, 360.0, 480.0, 600.0]
data_04 = Float64[0.0, 15.0, 30.0, 45.0, 60.0, 75.0]
# 1×6 row versions (not used further below in this script).
datas1 = reshape(data_01s, (1, 6))
datas2 = reshape(data_02s, (1, 6))
datas3 = reshape(data_03s, (1, 6))
datas4 = reshape(data_04s, (1, 6))
data_01 = reshape(data_01, (1, 6))
data_02 = reshape(data_02, (1, 6))
data_03 = reshape(data_03, (1, 6))
data_04 = reshape(data_04, (1, 6))

# Per-cell copies, indexed [risk group, age group, time]; every cell gets the same series.
datasf1 = Array{Float64}(undef, 3, 4, 6)
datasf2 = Array{Float64}(undef, 3, 4, 6)
datasm1 = Array{Float64}(undef, 3, 4, 6)
datasm2 = Array{Float64}(undef, 3, 4, 6)

dataincf01 = Array{Float64}(undef, 3, 4, 6)
dataincf02 = Array{Float64}(undef, 3, 4, 6)
dataincm01 = Array{Float64}(undef, 3, 4, 6)
dataincm02 = Array{Float64}(undef, 3, 4, 6)

for i in 1:3, j in 1:4
    datasf1[i, j, :] = Float64[0.0746, 0.0746, 0.0746, 0.0746, 0.0746, 0.0746]
    datasf2[i, j, :] = Float64[0.0048, 0.0048, 0.0048, 0.0048, 0.0048, 0.0048]
    datasm1[i, j, :] = Float64[0.01940, 0.01940, 0.01940, 0.01940, 0.01940, 0.01940]
    # NOTE(review): this repeats the female value 0.0746 from `datasf1`, whereas
    # `data_04s` above is 0.0048 and `prevm2` is 0.0194. Looks like a copy-paste slip;
    # confirm the intended target for this series.
    datasm2[i, j, :] = Float64[0.0746, 0.0746, 0.0746, 0.0746, 0.0746, 0.0746]
    dataincf01[i, j, :] = Float64[0.0, 1500.0, 3750.0, 5600.0, 7200.0, 8000.0]
    dataincf02[i, j, :] = Float64[0.0, 615.0, 1275.0, 1990.0, 2860.0, 3650.0]
    dataincm01[i, j, :] = Float64[0.0, 120.0, 240.0, 360.0, 480.0, 600.0]
    dataincm02[i, j, :] = Float64[0.0, 15.0, 30.0, 45.0, 60.0, 75.0]
end

# Stack the eight observed series along dim 3: data01[risk, age, series, time].
# cat along a new 4th dim gives [risk, age, time, series]; permutedims swaps the last two.
data01 = permutedims(cat(datasf1, datasf2, datasm1, datasm2,
                         dataincf01, dataincf02, dataincm01, dataincm02; dims = 4),
                     (1, 2, 4, 3))
prediction = similar(data01)

# Observation times; must match the last axis of data01.
ts = Float64[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]

# `I + 0.1 * J(x)`, with J the Jacobian of ODEHPV w.r.t. the state (at t = 0.0, params p),
# evaluated at every saved state of `sol`. Not used further below in this script.
Js = [I + 0.1 * ForwardDiff.jacobian(UJacobianWrapper(ODEHPV, 0.0, p), x) for x in sol.u]

"""
    predict_adjoint(p)

Solve the model with parameters `p` at the times `ts` and return the predicted
observations as an array of size `size(u0, 1) × size(u0, 2) × 8 × length(ts)`, laid out
like `data01`:

- series 1-2: compartments 2-3 as a fraction of `Ngsa[1, i, j]`;
- series 3-4: compartments 16-17 as a fraction of `Ngsa[2, i, j]`;
- series 5-8: compartments 9-12, unscaled.

Fixes:

- The function used to return only series 5-8: its last statement was an assignment,
  which evaluates to the right-hand side.
- It filled a global buffer in place; array mutation is not differentiable with Zygote
  (the default AD in `sciml_train`). The result is now built with `cat`.
- Prevalence was divided by the whole-population totals `NF`/`NM`. `u0` and the targets
  are per-cell fractions of `Ngsa` (e.g. 0.0746 == prevf1 at t = 0), so that scaling could
  never match the data.

NOTE(review): series 7-8 are compared with compartments 11-12, which belong to the
female block of ODEHPV — confirm this is the intended mapping for the male series.
"""
function predict_adjoint(p)
    _prob = remake(prob; p = p)
    osol = solve(_prob, Vern7(); saveat = ts,
                 sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true)))
    odedata = Array(osol)
    estimated01 = odedata[:, :, 2:3, :] ./ Ngsa[1, :, :]
    estimated02 = odedata[:, :, 16:17, :] ./ Ngsa[2, :, :]
    estimated03 = odedata[:, :, 9:12, :]
    return cat(estimated01, estimated02, estimated03; dims = 3)
end

"""
    loss_adjoint(p)

Return `(loss, prediction)`. `loss` is the root of the summed squared error between
`predict_adjoint(p)` and `data01`; the tuple form lets `sciml_train` pass `prediction` to
the callback.

NOTE(review): the target series differ in scale by ~5 orders of magnitude (fractions
around 0.07 vs counts up to 8000), so the unweighted sum is dominated by the count series.
Per-series weighting may help the fit.
"""
function loss_adjoint(p)
    prediction = predict_adjoint(p)
    loss = sqrt(sum(abs2, prediction .- data01))
    return loss, prediction
end

# Training callback for `sciml_train`: receives the current parameters `p`, the loss `l`
# and the extra return value of `loss_adjoint` (`pred`).
cb = function (p, l, pred)
    println("Loss: $l")
    println("Parameters: $(p)")
    # Re-solve with the current parameters (via `remake`) and plot the full trajectory.
    # This costs an extra ODE solve on every iteration.
    plot(solve(remake(prob, p = p), Vern7())) |> display
    return false  # false = keep optimizing; returning true stops the optimization
end

# Starting point for the fit (same values as the nominal `p`).
initp = [25.0, 1.425, 0.986, 0.81, 0.93, 0.10774, 0.0308, 0.05146, 0.00508,
         25.0, 1.016, 0.525, 0.4, 0.4, 0.04, 8.0, 10.0, 0.670, 0.538]

# Two stages: ADAM(0.01) from `initp`, then BFGS starting from ADAM's result.
res = DiffEqFlux.sciml_train(loss_adjoint, initp, ADAM(0.01);
                             cb = cb, maxiters = 150000)
res2 = DiffEqFlux.sciml_train(loss_adjoint, res.u, BFGS();
                              cb = cb, maxiters = 1500000, allow_f_increases = true)

1 Like

Did you try multiple shooting?

Or successively growing the fitting interval.


Also, decreasing tolerances helps to improve gradient accuracy.