Parameter Estimation of HPV Disease Model ( DiffEqFlux ) (beginner issue)

Hi,

I'm trying to estimate the parameters of an HPV transmission model with DiffEqFlux, but the optimization seems to get stuck in local optima. I'm not sure mini-batching would help, because my observation data has only 6 time points (T = 6). I started from the documentation tutorial on parameter estimation of stiff systems, which is why increases of the objective are allowed (`allow_f_increases=true`) in the BFGS stage. I'm really stuck on this issue and don't even know where to start; any help will make my day.

Thanks.
Mevlüt.


using DifferentialEquations, DiffEqFlux, LinearAlgebra
using ForwardDiff
using DiffEqBase: UJacobianWrapper
using Plots
using Random
# Seed the global RNG so that any random draws are reproducible between runs.
Random.seed!(15000)
# Sizes of the model's index dimensions. `g`, `s` and `a` are the lengths of
# `gendergr`, `sriskgr` and `agegr` below; `s_` and `a_` size the second risk
# and age axes of the 4-D/5-D mixing arrays (presumably the partner's risk and
# age group -- confirm against the model description).
const g = 2
const s = 3
const s_ = 3
const a = 4
const a_ = 4
# Index grids. Only their lengths are used by the `init_*` constructors.
const gendergr = range(0,stop=1,length=g)
const sriskgr = range(0,stop=1,length=s)
const agegr = range(0,stop=1,length=a)
# Number of observation times and a matching grid on [0, 1].
const TF = 6
const timefinale = range(0,stop=1,length=TF)
# Initial infected fractions used by `init_mevotoar`: `prevf1`/`prevf2` seed
# compartments 2/3 and `prevm1`/`prevm2` seed compartments 16/17.
# NOTE(review): `prevm2` equals `prevm1` (0.0194) while the observation series
# `data_04s` further down uses 0.0048 -- confirm which value is intended.
const prevf1 = 0.0746
const prevf2 = 0.0048
const prevm1 = 0.0194
const prevm2 = 0.0194
# Shares of each population bin across the `s` risk groups (each vector has
# length 3 and multiplies one entry of `Nfemale` in `init_ngsa`).
const Rsf1 = [0.475,0.3675,0.1575]
const Rsf2 = [0.809,0.1337,0.0573]
const Rsf3 = [0.929,0.0497,0.0213]
const Rsf4 = [0.955,0.0315,0.0135]
const Rsf5 = [0.979,0.0147,0.0063]
const Rsf6 = [0.971,0.0203,0.0087]
const Rsf7 = [0.976,0.0168,0.0072]
const Rsf8 = [0.9808,0.01344,0.00576]
# Same as above for the bins of `Nmale`.
const Rsm1 = [0.0986,0.63098,0.27042]
const Rsm2 = [0.5254,0.33222,0.14238]
const Rsm3 = [0.821,0.1253,0.0537]
const Rsm4 = [0.9201,0.05593,0.02397]
const Rsm5 = [0.9628,0.026,0.0112]
const Rsm6 = [0.971,0.0203,0.0087]
const Rsm7 = [0.9807,0.0135,0.0058]
const Rsm8 = [0.9907,0.00651,0.00279]
# Target population-weighted mean of the contact-rate array `Cgsa`
# (see how `cmin` is derived from it).
const coverall = 0.456
# Population sizes of 8 bins per sex; `init_ngsa` merges consecutive pairs of
# bins into the 4 age groups.
const Nfemale = [3211000.0, 3198000.0, 3100000.0, 3131000.0, 3269000.0, 2859000.0, 2499000.0, 3617000.0]
const Nmale = [3371000.0, 3303000.0, 3179000.0, 3201000.0, 3316000.0, 2888000.0, 2536000.0, 3425000.0]
# Total female / male population; used to normalise predictions in `predict_adjoint`.
const NF = sum(Nfemale)
const NM = sum(Nmale)

"""
    init_probmat(dz, js, ka)

Build the 5-D mixing array of size
`(length(dz), length(js), length(js), length(ka), length(ka))`.

For the first two slices of the first axis, the index pairs `(1,1)`, `(2,2)`,
`(3,3)`, `(2,3)` and `(3,2)` on axes 2 and 3 receive an identity matrix over
the last two axes; every other entry is zero. The index pairs are hard-coded,
so `dz` needs at least 2 elements and `js` at least 3.
"""
function init_probmat(dz,js,ka)
  nsex = length(dz)
  nrisk = length(js)
  nage = length(ka)

  mixing = zeros(nsex, nrisk, nrisk, nage, nage)

  # (axis-2, axis-3) index pairs that are linked to each other.
  linked = ((1,1), (2,2), (3,3), (2,3), (3,2))
  for sex in 1:2, (r, q) in linked
    mixing[sex, r, q, :, :] = I(nage)
  end

  # Replace any NaN entry by zero, in place.
  map!(x -> isnan(x) ? zero(x) : x, mixing, mixing)

  return mixing
end

# Mixing array built on the sex, risk-group and age-group grids defined above.
const probmat = init_probmat(gendergr,sriskgr,agegr)

"""
    init_ngsa(dz, js, ka)

Return the population array of size `(length(dz), length(js), length(ka))`.

Slice `[1, :, k]` combines the two female population bins `2k-1` and `2k`
(`Nfemale`), each split across the risk groups by its `Rsf*` share vector;
slice `[2, :, k]` does the same with `Nmale` and `Rsm*`. Four age groups are
filled, matching the 8 population bins.
"""
function init_ngsa(dz,js,ka)
  population = zeros(length(dz), length(js), length(ka))

  female_shares = (Rsf1, Rsf2, Rsf3, Rsf4, Rsf5, Rsf6, Rsf7, Rsf8)
  male_shares = (Rsm1, Rsm2, Rsm3, Rsm4, Rsm5, Rsm6, Rsm7, Rsm8)

  for k in 1:4
    lo = 2k - 1
    hi = 2k
    population[1,:,k] = (female_shares[lo]* Nfemale[lo]) + (female_shares[hi]* Nfemale[hi])
    population[2,:,k] = (male_shares[lo]* Nmale[lo]) + (male_shares[hi]* Nmale[hi])
  end

  return population
end


"""
    init_rgsa(dz, js, ka)

Return the relative-activity array of size
`(length(dz), length(js), length(ka))`.

Both slices `[1, :, :]` and `[2, :, :]` hold the outer product of the
per-risk-group weights `[1, 3, 5]` and the per-age-group weights
`[4, 3, 2, 1]`. The weights are hard-coded, so `js` must have 3 elements,
`ka` 4 elements and `dz` at least 2.
"""
function init_rgsa(dz,js,ka)
  Rgsa = zeros(length(dz), length(js), length(ka))

  Rs = [1.0,3.0,5.0]       # weight per risk group
  Ra = [4.0,3.0,2.0,1.0]   # weight per age group

  # The same risk-by-age table is used for both slices of the first axis, so
  # it is built once instead of allocating separate (and unused) buffers.
  risk_by_age = Rs*transpose(Ra)
  Rgsa[1,:,:] = risk_by_age
  Rgsa[2,:,:] = risk_by_age

  return Rgsa
end

# Population and relative-activity arrays, both indexed [sex, risk group, age group].
const Ngsa = init_ngsa(gendergr,sriskgr,agegr)
const Rgsa = init_rgsa(gendergr,sriskgr,agegr)
# Element-wise product, used only to normalise `cmin` below.
const Ngsa_Rgsa = Ngsa.*Rgsa


# Scale factor chosen so that the population-weighted mean of `Cgsa`,
# sum(Ngsa .* Cgsa) / sum(Ngsa), equals `coverall`.
const cmin = coverall*(sum(Ngsa)/sum(Ngsa_Rgsa))
const Cgsa = cmin .* Rgsa

"""
    init_balmat(Cgsa, probmat, Ngsa)

Return the 4-D balance array `B[i, k, j, l]`, the ratio

    (Cgsa[1,i,j] * probmat[1,i,k,j,l] * Ngsa[1,i,j]) /
    (Cgsa[2,i,j] * probmat[2,i,k,j,l] * Ngsa[2,i,j])

with `NaN` results (0/0, i.e. no mixing on either side) replaced by zero.
The output size is taken from axes 2 to 5 of `probmat`, so the function no
longer depends on the global size constants.

NOTE(review): the denominator uses the same `(i, j)` indices as the numerator
rather than the other-side indices `(k, l)`; confirm this matches the intended
balancing formula.
"""
function init_balmat(Cgsa,probmat,Ngsa)
  n_i, n_k, n_j, n_l = size(probmat)[2:5]
  balmat = zeros(n_i, n_k, n_j, n_l)

  # Bounds checks are kept: this runs once at start-up and the three inputs
  # are not guaranteed to have matching sizes.
  for l in 1:n_l, j in 1:n_j, k in 1:n_k, i in 1:n_i
    numerator = Cgsa[1,i,j]*probmat[1,i,k,j,l]*Ngsa[1,i,j]
    denominator = Cgsa[2,i,j]*probmat[2,i,k,j,l]*Ngsa[2,i,j]
    ratio = numerator / denominator
    balmat[i,k,j,l] = isnan(ratio) ? 0.0 : ratio
  end

  return balmat
end

# Balance ratios between the two sexes' contact offers, indexed [i, k, j, l].
const balmat = init_balmat(Cgsa,probmat,Ngsa)


"""
    init_Cmat(Cgsa, balmat)

Return the 5-D contact array `Cmat[sex, i, k, j, l]`:

- `Cmat[1,i,k,j,l] = Cgsa[1,i,j] * balmat[i,k,j,l]^0`, i.e. `Cgsa[1,i,j]`
  for every `(k, l)`, because `x^0` is 1 for any `x`;
- `Cmat[2,i,k,j,l] = Cgsa[2,i,j] * balmat[i,k,j,l]^1`.

The output size is `(size(Cgsa, 1), size(balmat)...)`, so the function no
longer depends on the global size constants. `Cgsa` needs at least two slices
along its first axis.
"""
function init_Cmat(Cgsa,balmat)
  Cmat = zeros(size(Cgsa, 1), size(balmat)...)

  # Bounds checks are kept: this runs once at start-up.
  for idx in CartesianIndices(balmat)
    i, k, j, l = Tuple(idx)
    # Exponents 0 and 1 put the whole balancing adjustment on slice 2.
    Cmat[1,i,k,j,l] = Cgsa[1,i,j]*balmat[i,k,j,l]^0
    Cmat[2,i,k,j,l] = Cgsa[2,i,j]*balmat[i,k,j,l]^(1)
  end

  return Cmat
end

# Contact array indexed [sex, i, k, j, l]; summed over (k, l) further down.
const Cmat = init_Cmat(Cgsa,balmat)

# Per-sex contact arrays, indexed [i, k, j, l].
actcmult_ = Cmat[1,:,:,:,:]
actcmult_2 = Cmat[2,:,:,:,:]
const cstar = 1.0

# Collapse axes 2 and 4 (k and l) by summation, leaving one contact rate per
# (risk group i, age group j). `actcmult` comes from slice 1 of `Cmat`,
# `actcmult2` from slice 2.
#
# The previous version computed the slice-2 sum but then dropped the dimensions
# of the slice-1 result again, so `actcmult2` was a copy of `actcmult`.
# Both results are `const` because `ODEHPV` reads them on every right-hand-side
# evaluation and non-const globals are untyped in Julia.
actcnew = mapslices(sum, actcmult_, dims = [2,4])
actcnew3 = dropdims(actcnew; dims=(2,4))
const actcmult =  cstar * actcnew3

actcnew2 = mapslices(sum, actcmult_2, dims = [2,4])
actcnew23 = dropdims(actcnew2; dims=(2,4))
const actcmult2 =  cstar * actcnew23



"""
    ODEHPV(du,u,p,t)

In-place right-hand side of the compartmental model.

`u` and `du` are indexed `[i, j, c]` with risk group `i in 1:s`, age group
`j in 1:a` and compartment `c in 1:19`. Compartments 1-14 are coupled to
`Beta`, `actcmult` and `Ngsa[2, ...]`; compartments 15-19 to `Beta_m`,
`actcmult2` and `Ngsa[1, ...]`. `p` holds the 19 rate parameters unpacked on
the first line; `t` is unused.

Every compartment leaves its age group at rate 0.1. Age groups `j >= 2`
receive `0.1*u[i,j-1,c]` from the previous age group; in age group 1 only
compartments 1 and 15 receive an inflow, proportional to the total of age
index 4 split by `Rsf1` / `Rsm1`.

The two loops of the original (age group 1 vs. the rest) shared all 19
equations and differed only in that inflow term, so the equations are written
once and the inflows are added afterwards. Each `du` entry is still the same
left-to-right sum as before.

NOTE(review): compartments 16 and 17 gain `u[i,j,19]` / `u[i,j,18]` at rate
`Beta*actcmult`, while compartments 19 and 18 lose them at rate
`Beta_m*actcmult2`. The two differ whenever `Beta != Beta_m`, so those flows
do not balance -- confirm which rate is intended.
"""
function ODEHPV(du,u,p,t)
  DIf, DIPf1, DIPf2, c1, c2, PCC1, PCC2, PRG1, PRG2,DIm, DIPm1, DIPm2, Beta, Beta_m, SCRN, DTR, DC, cm1, cm2  = p

  # Total over all risk groups and compartments of age index 4. It does not
  # depend on the loop indices, so it is evaluated once per call.
  oldest_total = sum(u[:,4,:])

  # Within-age-group dynamics plus the 0.1 ageing outflow.
  @inbounds for j in 1:a, i in 1:s
    du[i,j,1] = (u[i,j,13]+u[i,j,14])/DIf - u[i,j,1]*Beta*actcmult[i,j]*((u[i,j,16]/Ngsa[2,i,j])+(u[i,j,17]/Ngsa[2,i,j])) + u[i,j,8]/DC + u[i,j,9]/DTR + u[i,j,10]/DTR + u[i,j,11]/DTR + u[i,j,12]/DTR - 0.1*u[i,j,1]
    du[i,j,2] = u[i,j,1]*Beta * actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j])  + u[i,j,4]*(PCC1) + u[i,j,14]*Beta*actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j]) - u[i,j,2]/DIPf1 - u[i,j,2]*SCRN - 0.1*u[i,j,2]
    du[i,j,3] = u[i,j,1]*Beta * actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j])  + u[i,j,5]*(PCC1) + u[i,j,13]*Beta*actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j]) - u[i,j,3]/DIPf2 - u[i,j,3]*SCRN- 0.1*u[i,j,3]
    du[i,j,4] = u[i,j,2]*(1-c1)/DIPf1 + u[i,j,6]*PCC2 - u[i,j,4]*PRG1 - u[i,j,4]*(PCC1) - u[i,j,4]*SCRN - 0.1*u[i,j,4]
    du[i,j,5] = u[i,j,3]*(1-c2)/DIPf2 + u[i,j,7]*PCC2 - u[i,j,5]*PRG1 - u[i,j,5]*(PCC1)  - u[i,j,5]*SCRN - 0.1*u[i,j,5]
    du[i,j,6] = u[i,j,4]*PRG1 - u[i,j,6]*PCC2 - u[i,j,6]*SCRN - u[i,j,6]*PRG2 - 0.1*u[i,j,6]
    du[i,j,7] = u[i,j,5]*PRG1 - u[i,j,7]*PCC2 - u[i,j,7]*SCRN - u[i,j,7]*PRG2  - 0.1*u[i,j,7]
    du[i,j,8] = (u[i,j,6]+u[i,j,7])*PRG2 - u[i,j,8]*SCRN -u[i,j,8]/DC  - 0.1*u[i,j,8]
    du[i,j,9] = (u[i,j,2]+u[i,j,3])*SCRN - u[i,j,9]/DTR - 0.1*u[i,j,9]
    du[i,j,10] = (u[i,j,4]+u[i,j,5])*SCRN - u[i,j,10]/DTR - 0.1*u[i,j,10]
    du[i,j,11] = (u[i,j,6]+u[i,j,7])*SCRN - u[i,j,11]/DTR - 0.1*u[i,j,11]
    du[i,j,12] = u[i,j,8]*SCRN  - u[i,j,12]/DTR - 0.1*u[i,j,12]
    du[i,j,13] = u[i,j,2]*(c1)/DIPf1 - u[i,j,13]*Beta*actcmult[i,j]*(u[i,j,17]/Ngsa[2,i,j]) - u[i,j,13]/DIf  - 0.1*u[i,j,13]
    du[i,j,14] = u[i,j,3]*(c2)/DIPf2 - u[i,j,14]*Beta*actcmult[i,j]*(u[i,j,16]/Ngsa[2,i,j]) - u[i,j,14]/DIf - 0.1*u[i,j,14]
    du[i,j,15] = (u[i,j,18]+u[i,j,19])/DIm - u[i,j,15]*Beta_m*actcmult2[i,j]*((u[i,j,3]/Ngsa[1,i,j])+(u[i,j,2]/Ngsa[1,i,j])) - 0.1*u[i,j,15]
    du[i,j,16] = u[i,j,15]* Beta_m*actcmult2[i,j]*(u[i,j,2]/Ngsa[1,i,j]) +  u[i,j,19]*Beta*actcmult[i,j]*(u[i,j,2]/Ngsa[1,i,j]) - u[i,j,16]*cm1/DIPm1 - 0.1*u[i,j,16]
    du[i,j,17] = u[i,j,15]* Beta_m*actcmult2[i,j]*(u[i,j,3]/Ngsa[1,i,j]) +  u[i,j,18]*Beta*actcmult[i,j]*(u[i,j,3]/Ngsa[1,i,j]) - u[i,j,17]*cm2/DIPm2 - 0.1*u[i,j,17]
    du[i,j,18] = u[i,j,16]*cm1/DIPm1 - u[i,j,18]/DIm - u[i,j,18]*Beta_m*actcmult2[i,j]*(u[i,j,3]/Ngsa[1,i,j]) - 0.1*u[i,j,18]
    du[i,j,19] = u[i,j,17]*cm2/DIPm2 - u[i,j,19]/DIm - u[i,j,19]*Beta_m*actcmult2[i,j]*(u[i,j,2]/Ngsa[1,i,j]) - 0.1*u[i,j,19]
  end

  # Inflow into the first age group: only compartments 1 and 15 are fed.
  @inbounds for i in 1:s
    du[i,1,1] += 0.05*Rsf1[i]*oldest_total
    du[i,1,15] += 0.05*Rsm1[i]*oldest_total
  end

  # Ageing inflow from the previous age group, for every compartment.
  @inbounds for c in axes(u, 3), j in 2:a, i in 1:s
    du[i,j,c] += 0.1*u[i,j-1,c]
  end

  return nothing
end

"""
    init_mevotoar(js, ka)

Return the initial state array of size `(length(js), length(ka), 19)`.

Compartments 1-3 split `Ngsa[1, :, :]` into the fractions
`1 - prevf1 - prevf2`, `prevf1` and `prevf2`; compartments 15-17 split
`Ngsa[2, :, :]` into `1 - prevm1 - prevm2`, `prevm1` and `prevm2`. All other
compartments start at zero.
"""
function init_mevotoar(js,ka)
  # `zeros` already initialises the compartments that start empty
  # (4-14, 18 and 19), so only the populated ones are written.
  state = zeros(length(js), length(ka), 19)

  side1 = Ngsa[1,:,:]
  side2 = Ngsa[2,:,:]

  state[:,:,1] = (1-prevf1-prevf2)* side1
  state[:,:,2] = prevf1* side1
  state[:,:,3] = prevf2* side1

  state[:,:,15] = (1-prevm1-prevm2)* side2
  state[:,:,16] = prevm1* side2
  state[:,:,17] = prevm2* side2

  return state
end




# Initial state, indexed [risk group, age group, compartment].
const u0 = init_mevotoar(sriskgr,agegr)

# Rate parameters, in the order unpacked on the first line of `ODEHPV`
# (DIf, DIPf1, DIPf2, c1, c2, PCC1, PCC2, PRG1, PRG2, DIm, DIPm1, DIPm2,
#  Beta, Beta_m, SCRN, DTR, DC, cm1, cm2).
p = [25.0,1.425,0.986,0.81,0.93,0.10774,0.0308,0.05146,0.00508,25.0,1.016,0.525,0.4,0.4,0.04,8.0,10.0,0.670,0.538]
# Integration window; it covers the observation times `ts` (0 to 5).
tspan = (0.0,5.0)


# Reference problem and a solve at the initial parameters with default tolerances.
prob = ODEProblem(ODEHPV,u0,tspan,p)
sol = solve(prob,Vern7())

# Observed series, one value per entry of `ts` (6 time points).
# `data_0Xs` are constant fractions, `data_0X` are increasing counts.
data_01s = Float64[0.0746,0.0746,0.0746,0.0746,0.0746,0.0746]
data_02s = Float64[0.0048,0.0048,0.0048,0.0048,0.0048,0.0048]
data_03s = Float64[0.01940,0.01940,0.0194,0.0194,0.0194,0.0194]
data_04s = Float64[0.0048,0.0048,0.0048,0.0048,0.0048,0.0048]
data_01 = Float64[0.0,1500.0,3750.0,5600.0,7200.0,8000.0]
data_02 = Float64[0.0,615.0,1275.0,1990.0,2860.0,3650.0]
data_03 = Float64[0.0,120.0,240.0,360.0,480.0,600.0]
data_04 = Float64[0.0,15.0,30.0,45.0,60.0,75.0]

# 1 x 6 row-matrix views of the same series.
datas1 = reshape(data_01s, (1, 6))
datas2 = reshape(data_02s, (1, 6))
datas3 = reshape(data_03s, (1, 6))
datas4 = reshape(data_04s, (1, 6))
data_01 = reshape(data_01, (1, 6))
data_02 = reshape(data_02, (1, 6))
data_03 = reshape(data_03, (1, 6))
data_04 = reshape(data_04, (1, 6))

# Copy one series into every cell of the 3 x 4 (risk group x age group) grid,
# giving a 3 x 4 x 6 array.
replicate_series(series) = repeat(reshape(vec(series), 1, 1, :), 3, 4, 1)

# The gridded targets are derived from the series above so the two cannot
# drift apart. Previously `datasm2` was filled with 0.0746 (a copy of the
# `datasf1` line) instead of the fourth series, 0.0048.
# NOTE(review): `prevm2` (0.0194) disagrees with `data_04s` (0.0048); confirm
# which one is the intended target for this series.
datasf1 = replicate_series(data_01s)
datasf2 = replicate_series(data_02s)
datasm1 = replicate_series(data_03s)
datasm2 = replicate_series(data_04s)

dataincf01 = replicate_series(data_01)
dataincf02 = replicate_series(data_02)
dataincm01 = replicate_series(data_03)
dataincm02 = replicate_series(data_04)

# Stacked targets, indexed [risk group, age group, observable, time].
# NOTE(review): every grid cell is given the full series. Check that the
# per-cell model outputs compared against it are on the same scale, and that
# the count series (up to 8000) do not swamp the fractions (about 0.01 to 0.07)
# in an unweighted squared-error loss.
data01 = Array{Float64}(undef, 3,4,8,6)
# Global buffer with the same shape as `data01`.
prediction = Array{Float64}(undef, 3,4,8,6)

data01[:,:,1,:] = datasf1
data01[:,:,2,:] = datasf2
data01[:,:,3,:] = datasm1
data01[:,:,4,:] = datasm2
data01[:,:,5,:] = dataincf01
data01[:,:,6,:] = dataincf02
data01[:,:,7,:] = dataincm01
data01[:,:,8,:] = dataincm02


# Observation times passed to the solver as `saveat`.
ts = Float64[0.0,1.0,2.0,3.0,4.0,5.0]

# For every saved state of `sol`, evaluate I + 0.1*J, where J is the ForwardDiff
# Jacobian of `ODEHPV` with respect to `u` at t = 0 and the initial parameters.
# `Js` is not referenced by any other code visible in this script.
Js = map(u->I + 0.1*ForwardDiff.jacobian(UJacobianWrapper(ODEHPV, 0.0, p), u), sol.u)


"""
    predict_adjoint(p)

Solve the model with parameters `p`, saving at the times `ts`, and return the
observables as an array indexed `[risk group, age group, observable, time]`:

- observables 1-2: compartments 2 and 3 divided by `NF`;
- observables 3-4: compartments 16 and 17 divided by `NM`;
- observables 5-8: compartments 9 to 12, unscaled.

A fresh array is returned on every call. The previous version wrote the
slices into the global `prediction` buffer; reverse-mode AD cannot propagate
gradients through in-place writes into a global array, so the loss gradient
with respect to `p` could be lost or raise an error.
"""
function predict_adjoint(p)
    _prob = remake(prob,p=p)
    odedata = Array(solve(_prob, Vern7(),saveat=ts,sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
    estimated01 = odedata[:,:,2:3,:] ./ NF
    estimated02 = odedata[:,:,16:17,:] ./ NM
    estimated03 = odedata[:,:,9:12,:]
    # Stack along the observable axis without mutating any array.
    return cat(estimated01, estimated02, estimated03; dims=3)
end


"""
    loss_adjoint(p)

Return `(loss, prediction)`, where `prediction = predict_adjoint(p)` and
`loss` is the Euclidean norm of `prediction .- data01` (square root of the
sum of squared residuals).

Broadcasting replaces the former element-wise `map`, so a shape mismatch
between the prediction and `data01` now raises a `DimensionMismatch` instead
of being silently truncated to the shorter array.
"""
function loss_adjoint(p)
    prediction = predict_adjoint(p)
    loss = sqrt(sum(abs2, prediction .- data01))
    return loss, prediction
end

# Training callback: print the current loss and parameters, then plot the
# trajectory obtained with those parameters. `pred` is accepted but unused.
cb = function (p,l,pred)
    println("Loss: $l")
    println("Parameters: $(p)")
    # Re-solve with the current parameters `p` (via `remake`) and show the plot.
    trajectory = solve(remake(prob, p=p), Vern7())
    display(plot(trajectory))
    # Returning `false` keeps the optimisation running; `true` would stop it.
    return false
end


# Starting point of the optimisation; same values as `p` used to build `prob`.
initp = [25.0,1.425,0.986,0.81,0.93,0.10774,0.0308,0.05146,0.00508,25.0,1.016,0.525,0.4,0.4,0.04,8.0,10.0,0.670,0.538]

# Evaluate the loss once at the starting point and report it through the callback.
cb(initp,loss_adjoint(initp)...)

# Two-stage fit: ADAM with step size 0.01 first, then BFGS started from the
# ADAM result with `allow_f_increases=true`. The callback runs (and plots) on
# every iteration of both stages.
res = DiffEqFlux.sciml_train(loss_adjoint, initp, ADAM(0.01), cb = cb, maxiters = 150000)
res2 = DiffEqFlux.sciml_train(loss_adjoint, res.u, BFGS(), cb = cb, maxiters = 1500000, allow_f_increases=true)

1 Like

Did you try multiple shooting?

https://diffeqflux.sciml.ai/dev/examples/multiple_shooting/

Or successively growing the fitting interval?

https://diffeqflux.sciml.ai/dev/examples/local_minima/

Also, tightening the solver tolerances (smaller `abstol` / `reltol` in `solve`) helps to improve gradient accuracy.