Mixture models with binary indicator variables

I’m having trouble with several models, all of which involve a mixture model with binary indicator variables. I’m reasonably sure they all suffer from the same underlying issue, and I’m semi-convinced the problem lies with the sampling rather than the model specification. I still don’t have a good feel for which samplers to use in which situations. Any advice on the sampling, or on the models, would be appreciated.

Example 1

You can think of this as a set of test scores, each out of 40. We want to know whether each person is performing at chance (i.e., did not study: z = 0 and ψ = 0.5) or did study (z = 1), and, for those who studied, what the success rate ϕ of the study group is.

using Turing, StatsPlots
n_chains = 8

k = [21, 17, 21, 18, 22, 31, 31, 34, 34, 35, 35, 36, 39, 36, 35]   # commas give a Vector; spaces would give a 1×15 Matrix
n = 40

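@model function model(k, n)
    ψ = 0.5                                   # success rate at chance (did not study)
    ϕ ~ Uniform(0.5, 1)                       # success rate of the study group
    z ~ filldist(Bernoulli(0.5), length(k))   # z[i] = 1 if person i studied
    for i in eachindex(k)
        k[i] ~ Binomial(n, z[i] == 1 ? ϕ : ψ)
    end
end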

chains = sample(model(k, n), PG(100, 1), MCMCThreads(), 5000, n_chains)
# plot(chains)

density(vec(chains[:ϕ]), xlim=(0, 1), lw=3, legend=false, xlabel="ϕ", ylabel="Posterior density")

It should be the case that ϕ is sharply peaked at about 0.86, and that the first 5 people did not study (z = 0) while the rest did (z = 1). But the posterior over the indicator variables is not correct, and the resulting distribution of ϕ is noticeably off.
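As a sanity check on these expected values, the exact posterior can be computed without any MCMC, because summing out each z_i leaves a one-dimensional problem in ϕ. The sketch below does this on a grid in plain Julia; the helper names (`binom_logpdf`, `log_add_exp`, `mixlog`, `resp`) are hypothetical and written out by hand so nothing beyond Base is needed, and the grid size of 2000 is an arbitrary choice.

```julia
# Reference answer for the test-score mixture, computed on a grid over ϕ
# with the indicators z summed out analytically (no MCMC involved).

# Binomial log-pmf written out by hand so the sketch needs only Base Julia.
binom_logpdf(n, k, p) = log(binomial(n, k)) + k * log(p) + (n - k) * log1p(-p)

# Numerically stable log(exp(a) + exp(b)).
function log_add_exp(a, b)
    m = max(a, b)
    return m + log(exp(a - m) + exp(b - m))
end

k = [21, 17, 21, 18, 22, 31, 31, 34, 34, 35, 35, 36, 39, 36, 35]
n = 40
ψ = 0.5

# Midpoint grid over the support of the Uniform(0.5, 1) prior on ϕ.
G = 2000
ϕgrid = [0.5 + (j - 0.5) * 0.5 / G for j in 1:G]

# log p(k_i | ϕ) = log(0.5 * Bin(k_i | n, ϕ) + 0.5 * Bin(k_i | n, ψ))
mixlog(ki, ϕ) = log_add_exp(log(0.5) + binom_logpdf(n, ki, ϕ), log(0.5) + binom_logpdf(n, ki, ψ))

# Unnormalized log posterior on the grid (the uniform prior is constant there),
# then normalized weights.
logpost = [sum(mixlog(ki, ϕ) for ki in k) for ϕ in ϕgrid]
w = exp.(logpost .- maximum(logpost))
w ./= sum(w)

ϕ_mean = sum(w .* ϕgrid)

# P(z_i = 1 | k): average the conditional responsibility over the grid posterior.
resp(ki, ϕ) = exp(log(0.5) + binom_logpdf(n, ki, ϕ) - mixlog(ki, ϕ))
pz = [sum(w .* [resp(ki, ϕ) for ϕ in ϕgrid]) for ki in k]

println("posterior mean of ϕ ≈ ", round(ϕ_mean, digits = 3))
println("P(z_i = 1 | k) ≈ ", round.(pz, digits = 3))
```

It should report a posterior mean for ϕ close to 0.86, with P(z_i = 1 | k) near 0 for the first five people and near 1 for the rest, which gives a target to compare any sampler's output against.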

I recommend marginalizing out random variables when possible, even with other MCMC software. Of course, marginalizing out indicator variables is not always feasible, so this may not solve the problem for your other models.
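For the test-score model the sum over each indicator has only two terms, so the marginalization can be written down exactly. A sketch, writing Bin(k ∣ n, p) for the binomial pmf and using the Bernoulli(0.5) prior on z_i from Example 1:

$$
p(k_i \mid \phi) = \sum_{z_i \in \{0,1\}} p(z_i)\, p(k_i \mid z_i, \phi)
= \tfrac{1}{2}\,\mathrm{Bin}(k_i \mid n, \phi) + \tfrac{1}{2}\,\mathrm{Bin}(k_i \mid n, \psi)
$$

$$
P(z_i = 1 \mid k_i, \phi) = \frac{\tfrac{1}{2}\,\mathrm{Bin}(k_i \mid n, \phi)}{\tfrac{1}{2}\,\mathrm{Bin}(k_i \mid n, \phi) + \tfrac{1}{2}\,\mathrm{Bin}(k_i \mid n, \psi)}
$$

The first expression leaves ϕ as the only parameter, so a gradient-based sampler such as NUTS can handle it; in a Turing model, one option is to add its logarithm, computed with a log-sum-exp for numerical stability, through `Turing.@addlogprob!`. The second expression recovers the indicators afterwards by averaging it over the posterior draws of ϕ.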

In my experience, PG performs poorly. At some point, some of the Turing developers were working on adding JAGS-like samplers. I’m not sure what the status of this effort is. It would definitely be a valuable addition to the package.
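To illustrate what a JAGS-style Gibbs sampler does with this model, here is a minimal hand-written sketch in plain Julia. `gibbs_mixture` is a hypothetical name, not a Turing sampler. Each z_i is drawn exactly from its Bernoulli full conditional, while ϕ is drawn from its truncated-Beta full conditional by inverse-CDF sampling on a fine grid (a "griddy Gibbs" step, an approximation chosen so the sketch needs only the standard library).

```julia
using Random

# Hypothetical hand-written Gibbs sampler for the test-score mixture (not a
# Turing sampler). Each sweep updates z given ϕ, then ϕ given z.
function gibbs_mixture(k, n; ψ = 0.5, iters = 3000, G = 2000, rng = MersenneTwister(1))
    N = length(k)
    ϕgrid = [0.5 + (j - 0.5) * 0.5 / G for j in 1:G]  # midpoints covering (0.5, 1)
    z = falses(N)
    ϕ = 0.75
    ϕdraws = zeros(iters)
    zsum = zeros(N)
    for t in 1:iters
        # z[i] | ϕ, k[i]: with a Bernoulli(0.5) prior the posterior log-odds equal
        # the binomial log-likelihood ratio (the binomial coefficient cancels).
        for i in 1:N
            logodds = k[i] * (log(ϕ) - log(ψ)) + (n - k[i]) * (log1p(-ϕ) - log1p(-ψ))
            z[i] = rand(rng) < 1 / (1 + exp(-logodds))
        end
        # ϕ | z: Beta(1 + s, 1 + f) kernel restricted to (0.5, 1), where s and f
        # are the total successes and failures of people currently with z = 1.
        # Sampled by inverse CDF on the grid ("griddy Gibbs", an approximation).
        s = sum(k[z])
        f = sum(n .- k[z])
        logw = [s * log(p) + f * log1p(-p) for p in ϕgrid]
        cdf = cumsum(exp.(logw .- maximum(logw)))
        ϕ = ϕgrid[searchsortedfirst(cdf, rand(rng) * cdf[end])]
        ϕdraws[t] = ϕ
        zsum .+= z
    end
    return ϕdraws, zsum ./ iters  # draws of ϕ and the fraction of sweeps with z[i] = 1
end

k = [21, 17, 21, 18, 22, 31, 31, 34, 34, 35, 35, 36, 39, 36, 35]
ϕdraws, pz = gibbs_mixture(k, 40)
println("posterior mean of ϕ ≈ ", round(sum(ϕdraws) / length(ϕdraws), digits = 3))
println("P(z_i = 1 | k) ≈ ", round.(pz, digits = 2))
```

Updating z and ϕ in separate blocks is the usual Gibbs design for mixtures. In Turing, a similar split would pair a particle or discrete sampler for z with HMC for ϕ inside `Gibbs`, although the exact constructor syntax differs between Turing versions.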

using Turing, StatsPlots, ReverseDiff
n_chains = 4
# reverse-mode AD might be more efficient here; uncomment to try it (requires ReverseDiff)
#Turing.setadbackend(:reversediff)

k = [21, 17, 21, 18, 22, 31, 31, 34, 34, 35, 35, 36, 39, 36, 35]
n = 40

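@model function model(k, n)
    ψ = 0.5
    ϕ ~ Uniform(0.5, 1)
    # continuous relaxation: each z[i] lies in [0, 1] instead of being a 0/1 indicator
    z ~ filldist(Beta(1, 1), length(k))
    θ = (1 .- z) * ψ + z .* ϕ    # per-person success rate, between ψ and ϕ
    k .~ Binomial.(n, θ)
end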

chains = sample(model(k, n), NUTS(1000, .65), MCMCThreads(), 1000, n_chains)
# plot(chains)
density(vec(chains[:ϕ]), xlim=(0, 1), lw=3, legend=false, xlabel="ϕ", ylabel="Posterior density")

What I find here is that the posterior mean of ϕ is around 0.95, and the posterior means of z[i] for i in 1:5 are low, as expected.

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64 

        z[1]    0.1558    0.1130     0.0018    0.0014   4589.8165    0.9992
        z[2]    0.0912    0.0756     0.0012    0.0013   4726.6229    1.0000
        z[3]    0.1555    0.1114     0.0018    0.0016   4749.4378    1.0006
        z[4]    0.1032    0.0825     0.0013    0.0012   4633.8661    1.0001
        z[5]    0.1817    0.1208     0.0019    0.0017   4131.7823    1.0004
        z[6]    0.5782    0.1424     0.0023    0.0026   3883.3289    1.0004
        z[7]    0.5801    0.1411     0.0022    0.0023   3248.8748    0.9996
        z[8]    0.7337    0.1298     0.0021    0.0020   3188.9993    1.0007
        z[9]    0.7307    0.1308     0.0021    0.0026   2404.0744    1.0001
       z[10]    0.7800    0.1174     0.0019    0.0017   3434.4734    1.0005
       z[11]    0.7799    0.1210     0.0019    0.0021   2879.1882    1.0001
       z[12]    0.8263    0.1049     0.0017    0.0018   3573.7724    0.9998
       z[13]    0.9302    0.0608     0.0010    0.0010   5089.5668    1.0002
       z[14]    0.8260    0.1074     0.0017    0.0021   3380.7936    0.9999
       z[15]    0.7800    0.1217     0.0019    0.0028   2843.3654    1.0000
           ϕ    0.9541    0.0267     0.0004    0.0008   1204.4990    1.0005
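One thing worth keeping in mind when reading these numbers: replacing the Bernoulli indicators with Beta(1, 1) variables changes the model rather than marginalizing it. With z_i continuous on [0, 1],

$$
\theta_i = (1 - z_i)\,\psi + z_i\,\phi = \psi + z_i\,(\phi - \psi), \qquad 0 \le z_i \le 1 \;\Rightarrow\; \psi \le \theta_i \le \phi .
$$

So ϕ acts only as an upper bound on the individual success rates θ_i and must be at least as large as the largest of them. Person 13's score of 39 out of 40 (a raw rate of 0.975) therefore plausibly pulls ϕ upward, which would explain a mean near 0.95 rather than 0.86, and it is consistent with z[6] to z[15] behaving like rescaled rates that sit well below 1.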

As a side note, I received 171 warnings about numerical errors. Are you having the same experience?

Still on the list, but I don’t think there’s been active development in a while.

Thanks for the update. I hope to see this feature in the not-so-distant future.

By the way, would you recommend filing an issue about the warning messages on AdvancedHMC? The model above generates 150 to 180 warnings.

Hm, I don’t think so. I think this is mostly a consequence of very uninformative priors, but perhaps @Kai_Xu has more ideas about whether the number of numerical errors for this model merits an issue.

I can confirm that numerical stability improved with more informative priors. The model variant below produced only 12 warning messages.

@model function model(k, n)
    ψ = 0.5
    ϕ ~ truncated(Beta(40,10), .5, 1)
    z ~ filldist(Beta(20, 20), length(k))
    θ = (1 .- z) * ψ + z .* ϕ 
    k .~ Binomial.(n, θ)
end
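To see how much more informative these priors are, here is a small sketch computing the mean and standard deviation of each Beta prior from the closed-form Beta moments. The helper names are hypothetical; Beta(1, 1) is the same as Uniform(0, 1), and the untruncated moments of Beta(40, 10) are a close approximation to the truncated version because almost none of its mass lies below 0.5.

```julia
# Closed-form moments of Beta(a, b):
#   mean = a / (a + b),  variance = a*b / ((a + b)^2 * (a + b + 1))
beta_mean(a, b) = a / (a + b)
beta_sd(a, b) = sqrt(a * b / ((a + b)^2 * (a + b + 1)))

for (name, a, b) in [("Beta(1, 1)", 1, 1), ("Beta(40, 10)", 40, 10), ("Beta(20, 20)", 20, 20)]
    println(rpad(name, 14), " mean = ", round(beta_mean(a, b), digits = 3),
            ", sd = ", round(beta_sd(a, b), digits = 3))
end
```

The Beta(20, 20) prior on each z[i] concentrates most of its mass within roughly ±0.16 of 0.5, whereas Beta(1, 1) spreads it over the whole unit interval.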

If this is an unavoidable consequence of diffuse priors, I wonder whether it would be better to print a summary of the warnings rather than one message per warning. For example, a single message could read:

“156 numerical errors out of 2000 samples …”

Agreed; honestly, the warnings in AdvancedHMC are excessive and not very helpful. An ex-post diagnostic would be far more useful.


Seconded.

Thanks for the responses so far… Hoping to test these out and reply properly soon.

This actually seems quite important. I am currently running into issues where I would really like to have the total number of divergences/numerical errors, as Stan provides. But so far I have not been able to figure out how to get this in Turing.

That is actually already possible with Turing! The necessary information is in the internal fields of the Chains object that gets returned when you call sample.

This issue describes what the different internal fields mean. I also have a convenience function for counting the divergences encountered, which can be found here.
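As a sketch of the one-line ex-post summary suggested earlier in the thread, the hypothetical helper below (not the convenience function linked above) counts flagged iterations overall and per chain. It assumes the flags arrive as an iterations × chains array of 0/1 or Bool values; the internal `numerical_error` field of a Turing NUTS chain (for example `chains[:numerical_error]`) is the kind of input intended, but check the field names against the issue above for your Turing version.

```julia
# Hypothetical helper: summarize a per-iteration flag array
# (iterations × chains, nonzero/true = flagged) as a single message.
function error_summary(flags::AbstractMatrix)
    total = count(!iszero, flags)
    per_chain = [count(!iszero, col) for col in eachcol(flags)]
    return "$total numerical errors out of $(length(flags)) samples (per chain: $(join(per_chain, ", ")))"
end

# Toy input standing in for real sampler output: 5 iterations × 2 chains.
flags = [0 1; 0 0; 1 0; 0 0; 0 1]
println(error_summary(flags))
```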

Note that if you use ArviZ.jl for plotting, it can also highlight divergences, e.g. in a pair plot.

It is also important to note that the warnings AdvancedHMC prints are not about divergences! They indicate whether the parameter, the momentum, the log density, or the log density of the momentum is non-finite. In contrast, divergences indicate an error in the numerical integration of the HMC trajectory.
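To make the distinction concrete, here is an illustrative sketch of the two checks; `classify_proposal` is a hypothetical function, not AdvancedHMC's implementation. The first check is a finiteness test on the position θ, momentum r, log density ℓπ and momentum log density ℓκ, which is what the warnings report. The second compares the energy error of the trajectory against a threshold Δmax, which is how divergences are usually defined; the default of 1000 is an assumption mirroring Stan's default.

```julia
# Illustrative sketch only: classify one HMC proposal into the two kinds of
# problems discussed above. Hypothetical function, not AdvancedHMC's code.
#   θ, r  : position and momentum vectors of the proposal
#   ℓπ    : log density of the target at θ
#   ℓκ    : log density of the momentum at r (minus the kinetic energy, up to a constant)
#   H0    : Hamiltonian (total energy) at the start of the trajectory
function classify_proposal(θ, r, ℓπ, ℓκ, H0; Δmax = 1000.0)
    # Non-finite values: the situation the AdvancedHMC warnings describe.
    if !(all(isfinite, θ) && all(isfinite, r) && isfinite(ℓπ) && isfinite(ℓκ))
        return :nonfinite
    end
    # Divergence: the integrator failed to conserve energy, even though
    # every quantity is finite.
    H = -ℓπ - ℓκ               # potential energy + kinetic energy
    return H - H0 > Δmax ? :divergent : :ok
end

println(classify_proposal([0.1, 0.2], [0.3, Inf], -1.0, -0.5, 1.5))
println(classify_proposal([0.1, 0.2], [0.3, 0.4], -1.0, -0.5, 1.5))
println(classify_proposal([0.1, 0.2], [0.3, 0.4], -5000.0, -0.5, 1.5))
```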


Thanks for the tip, that helps a lot!