Bayesian estimation with Multinomial Likelihood and Dirichlet prior

I’m trying to fit a parasite dynamic model to some experimental data using Turing.jl, but I’m having some trouble specifying the prior variance, σ. It works fine if I use an InverseGamma distribution (i.e., resulting parameters seem reasonable), but since the data were drawn from a multinomial distribution, an InverseGamma prior and a MvNormal likelihood seem inappropriate. My understanding is that I would then use a Dirichlet prior and a multinomial likelihood (?). How would one implement this? The answer is definitely “learn more about Bayesian statistics”, but getting this working will go a long way toward helping me understand what’s going on without having to second guess my code.

The data are the number of preinfective larvae and infective larvae counted in multiple hosts at different days post-infection. I think the likelihood should be multinomial because at any given day there could be preinfective, infective, and/or dead larvae (although I don’t have data for the latter, so maybe Binomial? I get the same warnings in that case).

My attempts at using Multinomial or Binomial likelihoods result in endless warnings:

┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, true, false, true)
└  @ AdvancedHMC C:\Users\alexa\.julia\packages\AdvancedHMC\P9wqk\src\hamiltonian.jl:47

without ever getting a result.

The data are not great, but I’m hoping to get some information on development rates and mortality rates for the larvae.
I’m working off the Lotka-Volterra example at https://turing.ml/dev/tutorials/10-bayesiandiffeq/


using Turing, Distributions, DifferentialEquations

using MCMCChains, Plots, StatsPlots


#Model of the experimental infection protocol
const t_1 = 5.0
const γ = 1/t_1

function DevDelay_simpODE(du,u,p,t)
    #p = [muU, muI, m, tau] = [mortality rate preinfective, mort rate infective, initial # of U, delay period before developing to infectivity (1/development rate)]
  if t<t_1
    J = p[3]*γ - p[1]*u[1]*u[1] #inject some larvae at very start of experiment
  else
    J = 0.0
  end
  if t<p[4]
    S = 0.0
  else
    S = 1.0
  end
  du[1] = dU = J - p[1]*u[1]*u[1] - γ*u[1]*S #preinfective
  du[2] = dI = γ*u[1]*S - p[2]*u[2]          #infective
  du[3] = dD = p[1]*u[1]*u[1] + p[2]*u[2]    #dead
end

#Data (number of U and I in a host on a given day post-infection)
datax11 = [25,25,25,35,35,35,35,35,42,42,42,42,42,49,49,49,49,49,48,48,48,55,55,55,55,55,62,62,62,62,62,69,69,69,69,69,76,76,76,76,76].+5
datayU11 = [7,8,5,7,9,2,8,2,8,14,3,16,4,2,7,5,4,6,3,13,2,4,14,6,12,2,3,2,1,3,5,4,2,3,1,1,1,6,1,3,8]
datayI11 = [0,0,0,0,0,0,0,0,0,9,6,9,8,1,4,0,2,2,0,3,1,7,6,8,4,8,5,14,6,9,11,7,6,5,5,7,4,3,2,5,13]

data_11 = [datayU11 datayI11]'

Turing.setadbackend(:forwarddiff)

@model function fit11_5(data)
    tspan11 = (0.0,81.0)
    u0 = [0.0;0.0;0.0]
    #Priors
    # σ ~ InverseGamma(2, 3) #This "works"
    σ ~ Dirichlet(ones(2)/2) #But should be this if likelihood is Multinomial?

    muU ~ truncated(Normal(0.002,0.002),0.001,0.02)
    muI ~ truncated(Normal(0.002,0.002),0.001,0.02)
    m ~ truncated(Normal(12.5,5.0),0,90)
    tau ~ truncated(Normal(50.0,30.0),10.0,120.0)

    p11 = [muU, muI, m, tau]
    prob11 = ODEProblem(DevDelay_simpODE,u0,tspan11,p11)
    predicted = solve(prob11,Tsit5(),saveat=datax11)
    for i = 1:length(predicted)
        # data[:,i] ~ MvNormal(predicted[i][1:2], σ) #This "works" with InverseGamma
        data[:,i] ~ Multinomial(2, σ) #How do you incorporate σ into likelihood function? Does it act as the probability vector?
    end
end

model = fit11_5(data_11)
chain11_5 = sample(model, NUTS(.65),10)
plot(chain11_5)

This is a small part of a much more complicated fitting process. I was having some issues with convergence while using MLE, probably because of the small sample sizes. I thought I might have more luck with Bayes since I have some prior information I could incorporate. This is my first foray into Bayesian statistics, so any help would be greatly appreciated.

First I recommend using Binomial instead of Multinomial if your data only has two outcomes.

What do you mean by prior variance for Binomial distribution?

In your example the predicted is not used in the Binomial setting, so it is not clear what your model is supposed to be.
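
A quick check that might explain the non-finite log density in the warnings, assuming only Distributions.jl: Multinomial(2, σ) only supports count vectors that sum to 2, while the data columns sum to much more. The probability vector `pvec` below is a made-up placeholder, not something from your model.

```julia
using Distributions

pvec = [0.5, 0.5]   # illustrative probability vector (assumption, not fitted)
obs = [7, 0]        # first column of data_11: 7 U and 0 I

# Number of trials fixed at 2: the observation is outside the support
lp_wrong_n = logpdf(Multinomial(2, pvec), obs)

# Number of trials set to the observed total: the log density is finite
lp_right_n = logpdf(Multinomial(sum(obs), pvec), obs)

println((lp_wrong_n, lp_right_n))
```

If the first value comes out as -Inf, as I would expect, every proposal gets rejected no matter what the sampler tries.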


Thanks for replying. Sorry I’m slow to respond.

Of course predicted needs to be used, I don’t know what I was thinking! Here is a more thoughtful attempt:

@model function fit11_5(data)
    tspan11 = (0.0,81.0)
    u0 = [0.0;0.0;0.0]

    σ ~ Beta(1,1)
    muU ~ truncated(Normal(0.002,0.002),0.001,0.02) 
    muI ~ truncated(Normal(0.002,0.002),0.001,0.02)
    m ~ truncated(Normal(12.5,5.0),0,90)
    tau ~ truncated(Normal(50.0,30.0),10.0,120.0)
    
    p11 = [muU, muI, m, tau]
    prob11 = ODEProblem(DevDelay_simpODE,u0,tspan11,p11)
    predicted = solve(prob11,Tsit5(),saveat=datax11)
    for i = 1:length(predicted)
        pred_sum = round(Int, sum(predicted[i][1:2])) #Total number of larvae predicted by model with given parameters
        pred_prob = predicted[i][1] / sum(predicted[i][1:2]) #Proportion of larvae that are U
        data[1,i] ~ Binomial(pred_sum,pred_prob) #likelihood of observing data assuming it was drawn from a binomial with # of trials (n) = pred_sum and success rate (p) = pred_prob
    end
end

model = fit11_5(data_11)
@time chain11_5 = sample(model, NUTS(.65),10)
plot(chain11_5)

My understanding is that the probability of counting a preinfective larva (U) in one trial is initially distributed as Beta(1,1), and then this distribution is updated by calculating the likelihood for each observation (?)

So what I've done in the for loop is calculate pred_sum, the total number of larvae predicted by the model on a given day (used as the number of trials), and pred_prob, the proportion of those larvae that are preinfective (U) (used as the predicted success rate). I'm still confused about how the prior σ enters into it (this is what I confusingly called the prior variance of the binomial). I hope my intention is clearer now. I am still very new to the terminology.

Looks better now. I think the place for prior information should be in the predicted somehow. Otherwise it is just prior information on the accuracy of the prediction. Maybe your prior knowledge is about muU and muI and no sigma is needed at all?
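
To illustrate the point about σ: a Beta prior is only updated by the data if it actually appears in the likelihood. In the model above σ is sampled but never used, so its posterior should just look like Beta(1,1) again. A minimal sketch of what the update would look like if the probability of counting a U were itself the Binomial success rate (the single observation is the first host in the data; the setup is a simplification, not your ODE model):

```julia
using Distributions

prior = Beta(1, 1)   # flat prior on the probability of counting a U

# First host in the data: 7 preinfective (U) larvae out of 7 counted
k, n = 7, 7

# Conjugate Beta-Binomial update: add successes and failures to the prior counts
posterior = Beta(1 + k, 1 + (n - k))

println((mean(prior), mean(posterior)))
```

In your case the success rate comes from the ODE solution instead, so the priors that get updated are the ones on muU, muI, m and tau.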

There might still be a problem in the Binomial model. If I understand correctly the total number of observations data[1,i] + data[2,i] is not used. I think that should be used instead of pred_sum.


I think you may be right about not needing sigma. I think I was confused by the Lotka-Volterra example above where there were priors for the ode model parameters, but then also one for the variance of the normal distribution in the likelihood. So I just assumed there needed to be some variance parameter in the likelihood without really thinking about it. I also changed pred_sum as you suggested, which makes more sense, and now it works! I was getting a lot of numerical issue warnings but these were resolved by using BinomialLogit instead of Binomial. The fit to the data is not great, but I still need to figure out my priors better and try a few of the different samplers. Thanks a lot for your help!

This is what I ended up with:

@model function fit11_5(data)
    tspan11 = (0.0,81.0)
    u0 = [0.0;0.0;0.0]

    #Priors

    muU ~ truncated(Normal(0.001,0.002),0.001,0.02)
    muI ~ truncated(Normal(0.001,0.002),0.001,0.02)
    m ~ truncated(Normal(12.5,5.0),0,90)
    tau ~ truncated(Normal(50.0,50.0),10.0,120.0)

    p11 = [muU, muI, m, tau]
    prob11 = ODEProblem(DevDelay_simpODE,u0,tspan11,p11)
    predicted = solve(prob11,Tsit5(),saveat=datax11)
    for i = 1:length(predicted)
        lar_sum = data[1,i] + data[2,i] #Total number of larvae counted
        pred_prob = predicted[i][1] / sum(predicted[i][1:2]) #Proportion of larvae that are U
        data[1,i] ~ BinomialLogit(lar_sum,pred_prob) #likelihood of observing data with # of trials (n) = lar_sum; note that BinomialLogit interprets its second argument as the log-odds of success, so pred_prob is not used as a probability here
    end
end

model = fit11_5(data_11)
@time chain11_5 = sample(model, NUTS(.65),500)
# @time chain11_5 = sample(model, HMC(0.01,10),1000)
plot(chain11_5)
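
A note on the BinomialLogit parameterization: as far as I know, its second argument is the log-odds of success rather than the probability, so a probability has to be converted first. The sketch below uses plain Julia with two hypothetical helpers (`logit` and `logistic`, defined here for illustration; StatsFuns.jl provides equivalents), made-up predicted values, and an arbitrary clamping tolerance to keep the log-odds finite when the predicted number of I is exactly zero.

```julia
# Hypothetical helpers, defined here only for illustration
logit(p) = log(p / (1 - p))
logistic(x) = 1 / (1 + exp(-x))

# Made-up predicted states for one day (not model output)
U, I = 3.0, 7.0
pred_prob = U / (U + I)   # proportion of larvae that are U

# Success rate implied if pred_prob is read as log-odds
implied_p = logistic(pred_prob)

# Converting to log-odds first; the clamp tolerance is an arbitrary assumption
eps_p = 1e-6
logit_p = logit(clamp(pred_prob, eps_p, 1 - eps_p))
recovered_p = logistic(logit_p)

println((pred_prob, implied_p, recovered_p))
```

With that conversion, the likelihood line would read something like `BinomialLogit(lar_sum, logit_p)`. I have not rerun the fit this way, so treat it as something to check rather than a confirmed fix.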

After playing around a bit, I found the poor fit was due to my priors (which were arbitrary at this point). After changing the truncation bounds on the priors to (0.0, Inf) and switching to the MH sampler, it runs really fast, has a high ESS, and gives a good result. Thanks again for your help!

-Alex
