Strange issues fitting GMMs in Turing.jl by extending a fairly simple example

Thank you for the comprehensive answer, Martin :slight_smile: I'm trying this out. I don't really care about the discrete component assignments, so integrating them out like this is attractive. However, I think I'm now running into the classic identifiability issues of Bayesian mixture models (discussed in Betancourt's blog). Roughly, I think we need to impose an ordering on μ: this version does run pretty fast, but the outputs aren't great:

using Distributions, StatsPlots, Random, KernelDensity

# Set a random seed.
Random.seed!(3)

# Construct N = 1000 data points in total, drawn from the mixture (about N / K per cluster).
N = 1000

# Each cluster is an isotropic Gaussian; the means are the four corners of {-3, 3} × {-3, 3}.
# f returns every combination of its arguments' elements, as a vector of vectors.
f(iterators...) = vec([collect(x) for x in Iterators.product(iterators...)])
μs = f(-3:6:3, -3:6:3)

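# True within-cluster variance (per dimension).
σ²ₜ = 0.1

# InverseGamma(α, β) prior on σ²: α = β / σ²ₜ - 1 puts its mode, β / (α + 1), at σ²ₜ (here α = 19).
β = 2
α = (β / σ²ₜ) - 1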

K = length(μs)

gmm = MixtureModel(
    MvNormal.(μs, √σ²ₜ)
)

x = rand(gmm, N)

plot(kde(transpose(x)), title="Real Data")

using Turing, MCMCChains, DataFrames, FillArrays, StatsBase, Zygote

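# For each point n, this adds logw[k] + logpdf(...) over every component k,
# i.e. a sum of log-terms across components (not a logsumexp over k).
function logpfun(x, μ, σ², w)
    logw = log.(w)
    K = length(logw)
    N = size(x, 2)
    sum(sum(logw[k] + logpdf(MvNormal(μ[:,k], √(σ²[k])), x[:,n]) for k in 1:K) for n in 1:N)
end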

@model GMM2(x, K) = begin

    D, N = size(x)

    μ ~ filldist(MvNormal(zeros(D), 10), K)
    σ² ~ filldist(InverseGamma(α, β), K)
    w ~ Dirichlet(K, 1.0)

    Turing.@addlogprob! logpfun(x, μ, σ², w)
end

Turing.setadbackend(:forwarddiff)

gmm_model = GMM2(x, K)
gmm_sampler = NUTS(0.75)
tchain = sample(
    gmm_model,
    gmm_sampler,
    100,
)

This results in:

Chains MCMC chain (100×12×1 Array{Float64,3}):

Iterations        = 1:100
Thinning interval = 1
Chains            = 1
Samples per chain = 100
parameters        = μ[1,1], μ[1,2], μ[1,3], μ[1,4], μ[2,1], μ[2,2], μ[2,3], μ[2,4], σ²[1], σ²[2], σ²[3], σ²[4]
internals         =

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat
      Symbol   Float64   Float64    Float64   Missing     Float64   Float64

      μ[1,1]   -0.0945    0.0850     0.0085   missing    216.8390    0.9965
      μ[1,2]   -0.0861    0.1018     0.0102   missing   1847.2509    0.9932
      μ[1,3]   -0.0961    0.0883     0.0088   missing    378.2921    0.9903
      μ[1,4]   -0.0804    0.0948     0.0095   missing    530.6550    0.9917
      μ[2,1]   -0.0239    0.1112     0.0111   missing    652.5307    0.9903
      μ[2,2]   -0.0169    0.0937     0.0094   missing    178.1601    1.0152
      μ[2,3]   -0.0283    0.1025     0.0102   missing    300.8100    0.9985
      μ[2,4]   -0.0268    0.0813     0.0081   missing    218.1677    0.9900
       σ²[1]    8.9500    0.2832     0.0283   missing     96.9495    1.0024
       σ²[2]    8.9059    0.2923     0.0292   missing     92.0745    1.0378
       σ²[3]    8.9507    0.3030     0.0303   missing    121.2856    0.9926
       σ²[4]    8.9215    0.2604     0.0260   missing     23.1522    1.0476

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

      μ[1,1]   -0.2448   -0.1526   -0.0968   -0.0448    0.0719
      μ[1,2]   -0.2590   -0.1636   -0.0825   -0.0211    0.1102
      μ[1,3]   -0.2883   -0.1556   -0.1070   -0.0352    0.0915
      μ[1,4]   -0.2345   -0.1589   -0.0836   -0.0126    0.0976
      μ[2,1]   -0.2250   -0.0867   -0.0270    0.0429    0.2036
      μ[2,2]   -0.1856   -0.0902   -0.0086    0.0471    0.1695
      μ[2,3]   -0.2024   -0.0837   -0.0262    0.0262    0.1745
      μ[2,4]   -0.2074   -0.0857   -0.0226    0.0355    0.1116
       σ²[1]    8.4068    8.7597    8.9321    9.1382    9.4191
       σ²[2]    8.4053    8.6876    8.9105    9.0624    9.5809
       σ²[3]    8.4262    8.7340    8.9049    9.1691    9.5304
       σ²[4]    8.4053    8.7525    8.9500    9.1158    9.3155
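One thing that may be worth checking alongside the identifiability question is the likelihood itself. As written, `logpfun` adds `logw[k] + logpdf(...)` over k for each point, which is the log of a *product* over components rather than the log of the weighted *sum* Σₖ wₖ N(xₙ | μₖ, σ²ₖ I) that marginalising out the assignments gives. Because that sum of log-terms separates over k, each component is effectively fitted to all 1000 points on its own, which would be consistent with every component sitting at μ ≈ (0, 0) with σ² close to the per-axis variance of the data (about 3² + 0.1 ≈ 9.1). Below is a minimal, package-free sketch of the marginalised log-likelihood using a log-sum-exp over components; the helper names are hypothetical, and inside the Turing model one would more likely use `logpdf(MvNormal(...), ...)` from Distributions.jl together with `logsumexp` from LogExpFunctions.jl.

```julia
# Sketch: marginalised log-likelihood of an isotropic Gaussian mixture,
#   log p(x) = Σ_n log Σ_k w_k N(x_n | μ_k, σ²_k I).
# Written without packages so it runs standalone.

# Numerically stable log(sum(exp.(v))).
function stable_logsumexp(v::AbstractVector{<:Real})
    m = maximum(v)
    isfinite(m) || return m          # all -Inf (or an Inf) entries: nothing to rescale
    return m + log(sum(exp.(v .- m)))
end

# Log-density of N(xn | μk, σ²k * I) in D = length(xn) dimensions.
function iso_normal_logpdf(xn::AbstractVector, μk::AbstractVector, σ²k::Real)
    D = length(xn)
    sq = sum(abs2, xn .- μk)
    return -0.5 * (D * log(2π * σ²k) + sq / σ²k)
end

# x is D×N (one column per point), μ is D×K, σ² and w have length K.
function mixture_loglik(x::AbstractMatrix, μ::AbstractMatrix, σ²::AbstractVector, w::AbstractVector)
    logw = log.(w)
    return sum(axes(x, 2)) do n
        xn = view(x, :, n)
        # The sum over components happens *inside* the log, via logsumexp.
        stable_logsumexp([logw[k] + iso_normal_logpdf(xn, view(μ, :, k), σ²[k]) for k in eachindex(w)])
    end
end
```

Assuming that helper, the last line of `GMM2` would become `Turing.@addlogprob! mixture_loglik(x, μ, σ², w)`. This only changes the likelihood; label switching between components can still happen afterwards, which is where an ordering constraint comes in.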

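All four components end up in the same place (μ ≈ (0, 0) and σ² ≈ 9 in every component) rather than on the four clusters at (±3, ±3) with σ² = 0.1, so it must be struggling to identify the individual clusters. Do you know how we might impose an ordering on them? I have tried initialising the values at the high-density areas, as suggested by @Christopher_Fisher, but to no avail.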
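On the ordering question itself, one common approach (the same idea as Stan's `ordered` vector type) is to sample unconstrained parameters and map them deterministically to an increasing sequence: the first element is a free location and each later element adds a positive gap exp(zₖ). Below is a minimal sketch of that transform in plain Julia. `to_ordered` and `ordered_means` are hypothetical helper names, and ordering only the first coordinate of the means is an assumption; for a 2×2 grid like this one, where pairs of clusters share an x-coordinate, ordering one coordinate is a weaker constraint than in one dimension, so it may not fully fix the mixing.

```julia
# Sketch: map an unconstrained vector z to a strictly increasing vector y,
#   y[1] = z[1],   y[k] = y[k-1] + exp(z[k])  for k ≥ 2,
# i.e. the first entry is a free location and the rest are positive gaps.
# Avoids in-place mutation, which Zygote cannot differentiate.
function to_ordered(z::AbstractVector{<:Real})
    return vcat(z[1], z[1] .+ cumsum(exp.(z[2:end])))
end

# Build a D×K matrix of means whose first row (the x-coordinates) increases
# with k; `rest` holds the remaining (D-1)×K coordinates, left unordered.
function ordered_means(z_first::AbstractVector{<:Real}, rest::AbstractMatrix{<:Real})
    return vcat(permutedims(to_ordered(z_first)), rest)
end

# How it might slot into GMM2 (untested; variable names are hypothetical):
#   μx_start ~ Normal(0, 10)                   # first x-coordinate
#   log_gaps ~ filldist(Normal(0, 1), K - 1)   # log of gaps between x-coordinates
#   μy       ~ filldist(Normal(0, 10), K)      # y-coordinates, unordered
#   μ = ordered_means(vcat(μx_start, log_gaps), permutedims(μy))
```

Because the priors in that sketch are placed on the raw parameters (the start and the log-gaps), no Jacobian adjustment is needed; the prior on the ordered means is simply whatever those raw priors induce, which is worth checking with a prior predictive plot.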