Thank you for the comprehensive answer, Martin. I am trying this out. I don't really care about the discrete component assignments, so this version, which integrates them out, is attractive. However, I think I am now running into the classic identifiability issues associated with Bayesian mixtures (discussed in Betancourt's blog here). I think we need to impose an ordering on the μ somehow: currently this does run pretty fast, but the outputs aren't great:
```julia
using Distributions, StatsPlots, Random, KernelDensity

# Set a random seed.
Random.seed!(3)

# Construct N data points in total, spread across the clusters.
N = 1000

# Parameters for each cluster; we assume each cluster is Gaussian in this example.
# f builds the Cartesian product of its arguments as a vector of vectors.
f(iterators...) = vec([collect(x) for x in Iterators.product(iterators...)])
μs = f([-3:6:3, -3:6:3]...)   # cluster means on the grid {-3, 3} × {-3, 3}
σ²ₜ = 0.1                      # true within-cluster variance
β = 2
α = (β / σ²ₜ) - 1              # chosen so the InverseGamma(α, β) prior has its mode at σ²ₜ
K = length(μs)

gmm = MixtureModel(MvNormal.(μs, √σ²ₜ))
x = rand(gmm, N)
plot(kde(transpose(x)), title="Real Data")
```
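```julia
using Turing, MCMCChains, DataFrames, FillArrays, StatsBase, Zygote

# Log-likelihood term: adds logw[k] + logpdf(...) over every point n and every component k.
function logpfun(x, μ, σ², w)
    logw = log.(w)
    K = length(logw)
    N = size(x, 2)
    sum(sum(logw[k] + logpdf(MvNormal(μ[:,k], √(σ²[k])), x[:,n]) for k in 1:K) for n in 1:N)
end

# Mixture model without explicit assignments; the data term is added via @addlogprob!.
@model GMM2(x, K) = begin
    D, N = size(x)
    μ ~ filldist(MvNormal(zeros(D), 10), K)    # K component means, each D-dimensional
    σ² ~ filldist(InverseGamma(α, β), K)       # one isotropic variance per component
    w ~ Dirichlet(K, 1.0)                      # mixture weights
    Turing.@addlogprob! logpfun(x, μ, σ², w)
end

Turing.setadbackend(:forwarddiff)
gmm_model = GMM2(x, K)
gmm_sampler = NUTS(0.75)   # target acceptance rate 0.75
tchain = sample(gmm_model, gmm_sampler, 100)
```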
This results in:
```
Chains MCMC chain (100×12×1 Array{Float64,3}):

Iterations        = 1:100
Thinning interval = 1
Chains            = 1
Samples per chain = 100
parameters        = μ[1,1], μ[1,2], μ[1,3], μ[1,4], μ[2,1], μ[2,2], μ[2,3], μ[2,4], σ²[1], σ²[2], σ²[3], σ²[4]
internals         =

Summary Statistics
  parameters      mean       std  naive_se     mcse        ess     rhat
      Symbol   Float64   Float64   Float64  Missing    Float64  Float64

      μ[1,1]   -0.0945    0.0850    0.0085  missing   216.8390   0.9965
      μ[1,2]   -0.0861    0.1018    0.0102  missing  1847.2509   0.9932
      μ[1,3]   -0.0961    0.0883    0.0088  missing   378.2921   0.9903
      μ[1,4]   -0.0804    0.0948    0.0095  missing   530.6550   0.9917
      μ[2,1]   -0.0239    0.1112    0.0111  missing   652.5307   0.9903
      μ[2,2]   -0.0169    0.0937    0.0094  missing   178.1601   1.0152
      μ[2,3]   -0.0283    0.1025    0.0102  missing   300.8100   0.9985
      μ[2,4]   -0.0268    0.0813    0.0081  missing   218.1677   0.9900
       σ²[1]    8.9500    0.2832    0.0283  missing    96.9495   1.0024
       σ²[2]    8.9059    0.2923    0.0292  missing    92.0745   1.0378
       σ²[3]    8.9507    0.3030    0.0303  missing   121.2856   0.9926
       σ²[4]    8.9215    0.2604    0.0260  missing    23.1522   1.0476

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

      μ[1,1]   -0.2448   -0.1526   -0.0968   -0.0448    0.0719
      μ[1,2]   -0.2590   -0.1636   -0.0825   -0.0211    0.1102
      μ[1,3]   -0.2883   -0.1556   -0.1070   -0.0352    0.0915
      μ[1,4]   -0.2345   -0.1589   -0.0836   -0.0126    0.0976
      μ[2,1]   -0.2250   -0.0867   -0.0270    0.0429    0.2036
      μ[2,2]   -0.1856   -0.0902   -0.0086    0.0471    0.1695
      μ[2,3]   -0.2024   -0.0837   -0.0262    0.0262    0.1745
      μ[2,4]   -0.2074   -0.0857   -0.0226    0.0355    0.1116
       σ²[1]    8.4068    8.7597    8.9321    9.1382    9.4191
       σ²[2]    8.4053    8.6876    8.9105    9.0624    9.5809
       σ²[3]    8.4262    8.7340    8.9049    9.1691    9.5304
       σ²[4]    8.4053    8.7525    8.9500    9.1158    9.3155
```
It seems to be struggling to identify the individual clusters. Do you know how we might impose an ordering on them? I have tried initialising the values at the high-density areas, as suggested by @Christopher_Fisher, but to no avail.
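For reference, summing out the assignments gives a per-point mixture density, so the marginal log-likelihood is $\sum_{n=1}^{N} \log \sum_{k=1}^{K} w_k \, \mathcal{N}(x_n \mid \mu_k, \sigma^2_k I)$: a log-sum-exp over $k$ inside the sum over $n$. `logpfun` above instead adds `logw[k] + logpdf(...)` over both `n` and `k`, which makes every component account for every point. That would be consistent with all four μ sitting near the overall mean (about 0) and every σ² sitting near the overall per-coordinate variance of the data (roughly 3² + 0.1) in the summary. Below is a minimal sketch of the log-sum-exp form in plain Julia, assuming isotropic components as in `MvNormal(μ[:,k], √(σ²[k]))`. The helper names `stable_logsumexp`, `iso_normal_logpdf` and `mixture_loglik` are hypothetical, and `logsumexp` from LogExpFunctions.jl could stand in for the hand-written helper.

```julia
# Sketch: marginal log-likelihood of an isotropic Gaussian mixture, with the
# discrete assignments summed out per data point via log-sum-exp.
# Helper names are hypothetical; only Base Julia is used.

# Numerically stable log(sum(exp.(v))).
function stable_logsumexp(v::AbstractVector{<:Real})
    m = maximum(v)
    isfinite(m) || return m            # all -Inf, or an Inf/NaN present: return m directly
    return m + log(sum(exp.(v .- m)))
end

# log N(xn | μk, σ²k * I) for a D-dimensional isotropic Gaussian.
function iso_normal_logpdf(xn::AbstractVector, μk::AbstractVector, σ²k::Real)
    D = length(xn)
    return -0.5 * D * log(2π * σ²k) - sum(abs2, xn .- μk) / (2 * σ²k)
end

# x is D×N (one column per point), μ is D×K, σ² and w have length K.
# Returns Σₙ log Σₖ wₖ N(xₙ | μₖ, σ²ₖ I).
function mixture_loglik(x::AbstractMatrix, μ::AbstractMatrix,
                        σ²::AbstractVector, w::AbstractVector)
    logw = log.(w)
    K = length(w)
    return sum(axes(x, 2)) do n
        xn = view(x, :, n)
        stable_logsumexp([logw[k] + iso_normal_logpdf(xn, view(μ, :, k), σ²[k]) for k in 1:K])
    end
end

# Quick check: two identical components with equal weights reduce to one Gaussian.
x0 = zeros(2, 1)
println(mixture_loglik(x0, zeros(2, 2), [1.0, 1.0], [0.5, 0.5]) ≈ -log(2π))  # true
```

In the model, something like `Turing.@addlogprob! mixture_loglik(x, μ, σ², w)` would take the place of `logpfun`. This likelihood is still invariant to permuting the component labels, so label switching can remain and an ordering constraint may still be useful.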