I have been extending the Turing example on Bayesian GMMs in a number of ways: trying it out with non-symmetric 2D Gaussian locations, with more clusters than the two in the example I have linked, and in particular with some more tightly distributed components. This is where I have run into a problem I am struggling to diagnose, and I am wondering whether it is something worth looking into further. I first pick my parameters and generate some data:
using Distributions, StatsPlots, Random, KernelDensity
# Set a random seed.
Random.seed!(3)
# Number of observations.
N = 100
# Parameters for each cluster; we assume each cluster is Gaussian distributed,
# with means at the corners of the unit square.
f(iterators...) = vec([collect(x) for x in Iterators.product(iterators...)])
μs = f([0:1:1, 0:1:1]...)
# Common standard deviation, plus InverseGamma prior parameters chosen so the
# prior mode sits at σ_true (mode(InverseGamma(α, β)) = β / (α + 1)).
σ_true = 0.25
β = 5
α = (β / σ_true) - 1
K = length(μs)
gmm = MixtureModel(MvNormal.(μs, σ_true))
x = rand(gmm, N)
plot(kde(transpose(x)), title="Real Data")
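A quick aside on the prior: I picked α and β so that the InverseGamma mode lands exactly on σ_true (that coupling is my own choice, not something from the linked example). A one-line check:
mode(InverseGamma(α, β))  # β / (α + 1) = 5 / 20 = 0.25, i.e. σ_true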
This all goes smoothly and results in 4 components, one at each corner of the [0,1] square. I then define my model, which is the same as the one in the linked example except that we also learn a variance common to all the components, and we aren't restricted to the "x" and "y" locations being equal for each component. This works pretty well:
using Turing, MCMCChains, DataFrames, FillArrays, StatsBase, Zygote
@model GaussianMixtureModel(x, K) = begin
    D, N = size(x)
    # Draw the x- and y-coordinates of the K cluster means.
    μ1 ~ filldist(Normal(), K)
    μ2 ~ filldist(Normal(), K)
    # Draw the standard deviation shared by all the components.
    σ ~ InverseGamma(α, β)
    # Draw the weights for the K clusters from a Dirichlet distribution.
    w ~ Dirichlet(K, 1.0)
    # Swap the line above for this one to use fixed uniform weights instead:
    # w = Fill(1 / K, K)
    # Draw assignments for each datum and generate it from a multivariate normal.
    k = Vector{Int}(undef, N)
    for i in 1:N
        k[i] ~ Categorical(w)
        x[:, i] ~ MvNormal([μ1[k[i]], μ2[k[i]]], σ)
    end
    return k
end
Turing.setadbackend(:forwarddiff)
gmm_model = GaussianMixtureModel(x, K)
# Particle Gibbs for the discrete assignments, HMC for the continuous parameters.
gmm_sampler = Gibbs(PG(100, :k), HMC(0.05, 10, :μ1, :μ2, :σ))
tchain = sample(gmm_model, gmm_sampler, 100)
We get the following result:
Chains MCMC chain (100×9×1 Array{Float64,3}):
Iterations = 1:100
Thinning interval = 1
Chains = 1
Samples per chain = 100
parameters = μ1[1], μ1[2], μ1[3], μ1[4], μ2[1], μ2[2], μ2[3], μ2[4], σ
internals =
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Missing Float64 Float64
μ1[1] -0.0056 0.0748 0.0075 missing 65.7362 0.9919
μ1[2] 0.8684 0.0663 0.0066 missing 48.8895 1.0249
μ1[3] -0.0227 0.0415 0.0041 missing 49.3505 1.0110
μ1[4] 1.0091 0.0584 0.0058 missing 154.4181 0.9965
μ2[1] -0.0428 0.0649 0.0065 missing 96.1578 0.9958
μ2[2] -0.0445 0.0612 0.0061 missing 55.1864 0.9903
μ2[3] 1.0109 0.0497 0.0050 missing 40.8658 1.0210
μ2[4] 0.8448 0.0649 0.0065 missing 106.4267 0.9913
σ 0.2560 0.0161 0.0016 missing 30.7763 1.0707
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
μ1[1] -0.1235 -0.0680 -0.0088 0.0606 0.1347
μ1[2] 0.7557 0.8169 0.8696 0.9140 0.9999
μ1[3] -0.1105 -0.0533 -0.0203 0.0045 0.0457
μ1[4] 0.9032 0.9717 1.0161 1.0417 1.1016
μ2[1] -0.1538 -0.0903 -0.0455 0.0004 0.0760
μ2[2] -0.1634 -0.0864 -0.0433 0.0026 0.0553
μ2[3] 0.9242 0.9666 1.0141 1.0453 1.0953
μ2[4] 0.7104 0.8024 0.8390 0.8825 0.9661
σ 0.2287 0.2435 0.2552 0.2666 0.2883
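So the recovered means sit at the four corners as hoped, up to a permutation of the component labels, which is just label switching. For reference, a minimal sketch of how the posterior means can be pulled out and compared against μs (the Symbol indexing into the MCMCChains object is the part I'm assuming here):
# Posterior mean of each component location, to compare against μs.
post_μ = [[mean(tchain[Symbol("μ1[$j]")]), mean(tchain[Symbol("μ2[$j]")])] for j in 1:K]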
However, everything falls apart when I simply reduce σ_true a bit, to say 0.05.
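Concretely the only edit is the one below; note that because α is computed from σ_true, the prior sharpens along with the data when I rerun the script (that coupling is deliberate):
σ_true = 0.05         # was 0.25
α = (β / σ_true) - 1  # now 99, so the prior mode still tracks σ_true
# ...then rerun the data generation, model construction, and sampling as above.
With that change I get: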
Chains MCMC chain (100×9×1 Array{Float64,3}):
Iterations = 1:100
Thinning interval = 1
Chains = 1
Samples per chain = 100
parameters = μ1[1], μ1[2], μ1[3], μ1[4], μ2[1], μ2[2], μ2[3], μ2[4], σ
internals =
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Missing Float64 Float64
μ1[1] 1.2987 0.0000 0.0000 missing NaN NaN
μ1[2] 0.8141 0.0000 0.0000 missing 2.0408 0.9899
μ1[3] 0.2800 0.0000 0.0000 missing NaN NaN
μ1[4] -0.3730 0.0000 0.0000 missing NaN NaN
μ2[1] -0.0028 0.0000 0.0000 missing 2.0408 0.9899
μ2[2] 1.0843 0.0000 0.0000 missing 2.0408 0.9899
μ2[3] -0.9647 0.0000 0.0000 missing 2.0408 0.9899
μ2[4] -0.0501 0.0000 0.0000 missing NaN NaN
σ 0.0544 0.0000 0.0000 missing 2.0408 0.9899
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
μ1[1] 1.2987 1.2987 1.2987 1.2987 1.2987
μ1[2] 0.8141 0.8141 0.8141 0.8141 0.8141
μ1[3] 0.2800 0.2800 0.2800 0.2800 0.2800
μ1[4] -0.3730 -0.3730 -0.3730 -0.3730 -0.3730
μ2[1] -0.0028 -0.0028 -0.0028 -0.0028 -0.0028
μ2[2] 1.0843 1.0843 1.0843 1.0843 1.0843
μ2[3] -0.9647 -0.9647 -0.9647 -0.9647 -0.9647
μ2[4] -0.0501 -0.0501 -0.0501 -0.0501 -0.0501
σ 0.0544 0.0544 0.0544 0.0544 0.0544
Here it seems that no new values are accepted throughout sampling: the chain just remains at whatever it initially samples. I am not sure why sampling would break down in this way, unless I am missing something obvious. I have spoken to a couple of people and read around, and I know that MCMC struggles with GMMs, but I cannot really think where to start in trying to fix this, so I welcome any ideas or improvements to the more general GMM model I have formulated.
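In case it helps frame suggestions, one guess I keep coming back to (and it is only a guess) is that with σ_true = 0.05 the posterior over the means is roughly five times narrower, so the fixed HMC step size of 0.05 may simply be too coarse and every proposal gets rejected. The variation I would try is something like:
# Guesswork rather than a known fix: shrink the leapfrog step size to suit the
# much sharper posterior, and run longer so acceptance behaviour is visible.
gmm_sampler2 = Gibbs(PG(100, :k), HMC(0.005, 10, :μ1, :μ2, :σ))
tchain2 = sample(gmm_model, gmm_sampler2, 1000)
Does that reasoning hold, or is something else going on?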
Any help would be appreciated, thanks!