Strange issues fitting GMMs in Turing.jl by extending a fairly simple example

I have been extending the Turing example on Bayesian GMMs today in a number of ways: trying it out with non-symmetric 2D Gaussian locations, more clusters than the two in the example I have linked, and in particular with some more tightly distributed components. This is where I have run into a problem I am struggling to diagnose, and I am wondering if it is something worth looking into further. I first pick my parameters and generate some data:

using Distributions, StatsPlots, Random, KernelDensity

# Set a random seed.
Random.seed!(3)

N = 100

# Parameters for each cluster; we assume each cluster is Gaussian distributed in this example.
# f takes the Cartesian product of the coordinate ranges, giving one mean vector per cluster corner.
f(iterators...) = vec([collect(x) for x in Iterators.product(iterators...)])
μs = f([0:1:1, 0:1:1]...)

σ_true = 0.25

# Hyperparameters for an InverseGamma prior on σ, chosen so that the prior mode β / (α + 1) equals σ_true.
β = 5
α = (β / σ_true) - 1

K = length(μs)

gmm = MixtureModel(
    MvNormal.(μs, σ_true)
)

x = rand(gmm, N)

plot(kde(transpose(x)), title="Real Data")

This all goes smoothly and results in 4 components, one at each corner of the [0,1] square. I then define my model, which is the same as the one in the linked example except that we also learn a variance common to all the components and we aren’t restricted to the “x” and “y” locations being equal for each component. This works pretty well:


using Turing, MCMCChains, DataFrames, FillArrays, StatsBase, Zygote

@model GaussianMixtureModel(x, K) = begin
    
    D, N = size(x)

    # Draw the x- and y-coordinates of the K cluster means.
    μ1 ~ filldist(Normal(), K)
    μ2 ~ filldist(Normal(), K)

    # Draw a standard deviation shared by all components.
    σ ~ InverseGamma(α, β)

    # Draw the weights for the K clusters from a Dirichlet distribution.
    w ~ Dirichlet(K, 1.0)

    # Use this instead of the Dirichlet draw if you want fixed, equal weights.
    # w = Fill(1 / K, K)
    
    # Draw assignments for each datum and generate it from a multivariate normal.
    k = Vector{Int}(undef, N)
    for i in 1:N
        k[i] ~ Categorical(w)
        x[:,i] ~ MvNormal([μ1[k[i]], μ2[k[i]]], σ)
    end
    return k
end

Turing.setadbackend(:forwarddiff)

gmm_model = GaussianMixtureModel(x, K)
gmm_sampler = Gibbs(PG(100, :k), HMC(0.05, 10, :μ1, :μ2, :σ))
tchain = sample(
    gmm_model,
    gmm_sampler,
    100
)

We get the following result:

Chains MCMC chain (100×9×1 Array{Float64,3}):

Iterations        = 1:100
Thinning interval = 1
Chains            = 1
Samples per chain = 100
parameters        = μ1[1], μ1[2], μ1[3], μ1[4], μ2[1], μ2[2], μ2[3], μ2[4], σ
internals         =

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat
      Symbol   Float64   Float64    Float64   Missing    Float64   Float64

       μ1[1]   -0.0056    0.0748     0.0075   missing    65.7362    0.9919
       μ1[2]    0.8684    0.0663     0.0066   missing    48.8895    1.0249
       μ1[3]   -0.0227    0.0415     0.0041   missing    49.3505    1.0110
       μ1[4]    1.0091    0.0584     0.0058   missing   154.4181    0.9965
       μ2[1]   -0.0428    0.0649     0.0065   missing    96.1578    0.9958
       μ2[2]   -0.0445    0.0612     0.0061   missing    55.1864    0.9903
       μ2[3]    1.0109    0.0497     0.0050   missing    40.8658    1.0210
       μ2[4]    0.8448    0.0649     0.0065   missing   106.4267    0.9913
           σ    0.2560    0.0161     0.0016   missing    30.7763    1.0707

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

       μ1[1]   -0.1235   -0.0680   -0.0088    0.0606    0.1347
       μ1[2]    0.7557    0.8169    0.8696    0.9140    0.9999
       μ1[3]   -0.1105   -0.0533   -0.0203    0.0045    0.0457
       μ1[4]    0.9032    0.9717    1.0161    1.0417    1.1016
       μ2[1]   -0.1538   -0.0903   -0.0455    0.0004    0.0760
       μ2[2]   -0.1634   -0.0864   -0.0433    0.0026    0.0553
       μ2[3]    0.9242    0.9666    1.0141    1.0453    1.0953
       μ2[4]    0.7104    0.8024    0.8390    0.8825    0.9661
           σ    0.2287    0.2435    0.2552    0.2666    0.2883

However, everything falls apart when I simply reduce σ_true a bit to say 0.05:

Chains MCMC chain (100×9×1 Array{Float64,3}):

Iterations        = 1:100
Thinning interval = 1
Chains            = 1
Samples per chain = 100
parameters        = μ1[1], μ1[2], μ1[3], μ1[4], μ2[1], μ2[2], μ2[3], μ2[4], σ
internals         =

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat
      Symbol   Float64   Float64    Float64   Missing   Float64   Float64

       μ1[1]    1.2987    0.0000     0.0000   missing       NaN       NaN
       μ1[2]    0.8141    0.0000     0.0000   missing    2.0408    0.9899
       μ1[3]    0.2800    0.0000     0.0000   missing       NaN       NaN
       μ1[4]   -0.3730    0.0000     0.0000   missing       NaN       NaN
       μ2[1]   -0.0028    0.0000     0.0000   missing    2.0408    0.9899
       μ2[2]    1.0843    0.0000     0.0000   missing    2.0408    0.9899
       μ2[3]   -0.9647    0.0000     0.0000   missing    2.0408    0.9899
       μ2[4]   -0.0501    0.0000     0.0000   missing       NaN       NaN
           σ    0.0544    0.0000     0.0000   missing    2.0408    0.9899

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

       μ1[1]    1.2987    1.2987    1.2987    1.2987    1.2987
       μ1[2]    0.8141    0.8141    0.8141    0.8141    0.8141
       μ1[3]    0.2800    0.2800    0.2800    0.2800    0.2800
       μ1[4]   -0.3730   -0.3730   -0.3730   -0.3730   -0.3730
       μ2[1]   -0.0028   -0.0028   -0.0028   -0.0028   -0.0028
       μ2[2]    1.0843    1.0843    1.0843    1.0843    1.0843
       μ2[3]   -0.9647   -0.9647   -0.9647   -0.9647   -0.9647
       μ2[4]   -0.0501   -0.0501   -0.0501   -0.0501   -0.0501
           σ    0.0544    0.0544    0.0544    0.0544    0.0544

Here it seems that no new values are accepted throughout sampling, and the chain just remains at whatever it manages to initially sample. I am not sure why sampling would break down in this way, unless I am missing something obvious. I have spoken to a couple of people and read around, and I know that MCMC struggles with GMMs, but I cannot really think where to start in trying to fix this, so I welcome any ideas / improvements to the more general GMM model I have formulated.

Any help would be appreciated, thanks!

One thing might be that this should be a TArray, so that the particle sampler (PG) can copy it correctly when particles are forked:

k = TArray(Int, N)

There might be a few issues here. One potential issue is that the posterior density is concentrated in a small area when sigma is small, making it hard for the sampler to find high density areas. This might be exacerbated by performance issues with PG. You might try reparameterization or initializing the sampler in high density areas. In addition, marginalization might help. See this other thread.

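As a rough illustration of the concentration issue (just intuition, not a fix), compare how quickly the log density falls off for the two values of sigma; with a small sigma a proposal that lands even slightly off the mode takes a large log-probability hit, and that penalty is paid once per data point:

using Distributions

# Log-density drop for a point 0.5 away from the mode, i.e. -0.5^2 / (2σ^2).
logpdf(Normal(0.0, 0.25), 0.5) - logpdf(Normal(0.0, 0.25), 0.0)  # ≈ -2
logpdf(Normal(0.0, 0.05), 0.5) - logpdf(Normal(0.0, 0.05), 0.0)  # = -50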

Hi, can you integrate out the discrete variables, or do you want to sample them with Turing?

Anyway, here is an implementation of the GMM model using particle Gibbs.

@model GaussianMixtureModel(x, K) = begin

    D, N = size(x)

    μ ~ filldist(MvNormal(zeros(D), 1), K)
    σ ~ filldist(InverseGamma(α, β), K)

    w ~ Dirichlet(K, 1.0)

    k = TArray(Int, N)
    for i in 1:N
        k[i] ~ Categorical(w)
        x[:,i] ~ MvNormal(μ[:,k[i]], sqrt(σ[k[i]]))
    end
    return k
end

gmm_model = GaussianMixtureModel(x, K)
gmm_sampler = Gibbs(PG(10, :k), HMC(0.05, 10, :μ, :σ, :w))
tchain = sample(
    gmm_model,
    gmm_sampler,
    100
)

Alternatively, you can use GibbsConditional to sample the discrete variables. This should be more efficient. The function for the Gibbs conditional should look something like this:

function cond_k(c; x=x)
    K = length(c.w)
    N = size(x, 2)
    dists = map(1:N) do n
        # softmax! is from StatsFuns; it normalises the exponentiated log weights in place.
        logp = map(k -> log(c.w[k]) + logpdf(MvNormal(c.μ[:,k], sqrt(c.σ[k])), x[:,n]), 1:K)
        return Categorical(softmax!(logp))
    end
    return arraydist(dists)
end

gmm_sampler = Gibbs(GibbsConditional(:k, cond_k), HMC(0.05, 10, :μ, :σ, :w))

but I haven’t tried it. So you might have to fix some minor errors I made.

If you only want speed and don’t care about the discrete parameters then you might want to write:

function logpfun(x, μ, σ, w)
    logw = log.(w)
    K = length(logw)
    N = size(x, 2)
    sum(sum(logw[k] + logpdf(MvNormal(μ[:,k], sqrt(σ[k])), x[:,n]) for k in 1:K) for n in 1:N)
end

@model GaussianMixtureModel(x, K) = begin

    D, N = size(x)

    μ ~ filldist(MvNormal(zeros(D), 1), K)
    σ ~ filldist(InverseGamma(α, β), K)
    w ~ Dirichlet(K, 1.0)

    Turing.@addlogprob! logpfun(x, μ, σ, w)
end

gmm_model = GaussianMixtureModel(x, K)
gmm_sampler = NUTS(0.75)
tchain = sample(
    gmm_model,
    gmm_sampler,
    100
)
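
If you later decide you do want the assignments after all, you can recover soft assignments from point estimates of the parameters (e.g. posterior means taken from tchain). A rough, untested sketch:

using Distributions, StatsFuns  # softmax is from StatsFuns

# Posterior responsibility r[k, n] of component k for data point n, given μ, σ and w.
function responsibilities(x, μ, σ, w)
    K = length(w)
    N = size(x, 2)
    r = zeros(K, N)
    for n in 1:N
        logp = [log(w[k]) + logpdf(MvNormal(μ[:,k], sqrt(σ[k])), x[:,n]) for k in 1:K]
        r[:, n] = softmax(logp)
    end
    return r
end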

Thank you for the comprehensive answer Martin :slight_smile: I am trying this stuff out. I don’t really care about the discrete component assignments, so this version that integrates them out is attractive, except I think I am now running into the classic identifiability issues associated with Bayesian mixtures (discussed in Betancourt’s blog here). I think we roughly need to somehow impose an order on the μ, as currently this does indeed run pretty fast but the outputs aren’t great:

using Distributions, StatsPlots, Random, KernelDensity

# Set a random seed.
Random.seed!(3)

# Total number of data points.
N = 1000

# Parameters for each cluster; we assume each cluster is Gaussian distributed in this example.

f(iterators...) = vec([collect(x) for x in Iterators.product(iterators...)])
μs = f([-3:6:3, -3:6:3]...)

σ²ₜ = 0.1

# As before, chosen so the InverseGamma(α, β) prior mode β / (α + 1) equals σ²ₜ.
β = 2
α = (β / σ²ₜ) - 1

K = length(μs)

gmm = MixtureModel(
    MvNormal.(μs, √σ²ₜ)
)

x = rand(gmm, N)

plot(kde(transpose(x)), title="Real Data")

using Turing, MCMCChains, DataFrames, FillArrays, StatsBase, Zygote

function logpfun(x, μ, σ², w)
    logw = log.(w)
    K = length(logw)
    N = size(x, 2)
    sum(sum(logw[k] + logpdf(MvNormal(μ[:,k], √(σ²[k])), x[:,n]) for k in 1:K) for n in 1:N)
end

@model GMM2(x, K) = begin

    D, N = size(x)

    μ ~ filldist(MvNormal(zeros(D), 10), K)
    σ² ~ filldist(InverseGamma(α, β), K)
    w ~ Dirichlet(K, 1.0)

    Turing.@addlogprob! logpfun(x, μ, σ², w)
end

Turing.setadbackend(:forwarddiff)

gmm_model = GMM2(x, K)
gmm_sampler = NUTS(0.75)
tchain = sample(
    gmm_model,
    gmm_sampler,
    100,
)

Results in:

Chains MCMC chain (100×12×1 Array{Float64,3}):

Iterations        = 1:100
Thinning interval = 1
Chains            = 1
Samples per chain = 100
parameters        = μ[1,1], μ[1,2], μ[1,3], μ[1,4], μ[2,1], μ[2,2], μ[2,3], μ[2,4], σ²[1], σ²[2], σ²[3], σ²[4]
internals         =

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat
      Symbol   Float64   Float64    Float64   Missing     Float64   Float64

      μ[1,1]   -0.0945    0.0850     0.0085   missing    216.8390    0.9965
      μ[1,2]   -0.0861    0.1018     0.0102   missing   1847.2509    0.9932
      μ[1,3]   -0.0961    0.0883     0.0088   missing    378.2921    0.9903
      μ[1,4]   -0.0804    0.0948     0.0095   missing    530.6550    0.9917
      μ[2,1]   -0.0239    0.1112     0.0111   missing    652.5307    0.9903
      μ[2,2]   -0.0169    0.0937     0.0094   missing    178.1601    1.0152
      μ[2,3]   -0.0283    0.1025     0.0102   missing    300.8100    0.9985
      μ[2,4]   -0.0268    0.0813     0.0081   missing    218.1677    0.9900
       σ²[1]    8.9500    0.2832     0.0283   missing     96.9495    1.0024
       σ²[2]    8.9059    0.2923     0.0292   missing     92.0745    1.0378
       σ²[3]    8.9507    0.3030     0.0303   missing    121.2856    0.9926
       σ²[4]    8.9215    0.2604     0.0260   missing     23.1522    1.0476

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

      μ[1,1]   -0.2448   -0.1526   -0.0968   -0.0448    0.0719
      μ[1,2]   -0.2590   -0.1636   -0.0825   -0.0211    0.1102
      μ[1,3]   -0.2883   -0.1556   -0.1070   -0.0352    0.0915
      μ[1,4]   -0.2345   -0.1589   -0.0836   -0.0126    0.0976
      μ[2,1]   -0.2250   -0.0867   -0.0270    0.0429    0.2036
      μ[2,2]   -0.1856   -0.0902   -0.0086    0.0471    0.1695
      μ[2,3]   -0.2024   -0.0837   -0.0262    0.0262    0.1745
      μ[2,4]   -0.2074   -0.0857   -0.0226    0.0355    0.1116
       σ²[1]    8.4068    8.7597    8.9321    9.1382    9.4191
       σ²[2]    8.4053    8.6876    8.9105    9.0624    9.5809
       σ²[3]    8.4262    8.7340    8.9049    9.1691    9.5304
       σ²[4]    8.4053    8.7525    8.9500    9.1158    9.3155

So it must be struggling to identify the individual clusters. Do you know how we might impose an ordering on them? I have tried initialising the values at the high density areas as suggested by @Christopher_Fisher, but to no avail.
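
For concreteness, the kind of constraint I have in mind is something like the rough, untested sketch below. It assumes a version of Bijectors that exports ordered, and note that ordering only the x-coordinates won't fully break the symmetry here, since two components share each x value:

using Bijectors

@model GMM2Ordered(x, K) = begin

    D, N = size(x)

    # Constrain the x-coordinates of the means to be increasing to (partly)
    # break label switching; the y-coordinates and the rest stay as in GMM2.
    μ1 ~ ordered(MvNormal(zeros(K), 10))
    μ2 ~ filldist(Normal(0, 10), K)
    σ² ~ filldist(InverseGamma(α, β), K)
    w ~ Dirichlet(K, 1.0)

    μ = vcat(μ1', μ2')  # D × K matrix of component means
    Turing.@addlogprob! logpfun(x, μ, σ², w)
end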

Part of the problem is that the log densities of the mixture components cannot simply be added with sum. Instead, you need to use logsumexp from StatsFuns in the inner summation. logsumexp is a numerically stable and accurate way to compute the log of a sum of exponentials:

julia> using StatsFuns

julia> w = [.2, .8]
2-element Array{Float64,1}:
 0.2
 0.8

julia> dens = [2.0,1.5]
2-element Array{Float64,1}:
 2.0
 1.5

julia> log(w'*dens)
0.47000362924573563

julia> logsumexp(log.(w) .+ log.(dens)) # =)
0.47000362924573563

julia> sum(log.(w) .+ log.(dens)) # =(
-0.7339691750802003

It appears to work better with the following function:

function logpfun(x, μ, σ², w)
    logw = log.(w)
    K = length(logw)
    N = size(x, 2)
    sum(logsumexp(logw[k] + logpdf(MvNormal(μ[:,k], √(σ²[k])), x[:,n]) for k in 1:K) for n in 1:N)
end

You should increase the number of samples because effective sample size is low and rhat is high in some cases.

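Something like this would draw more samples across several chains so you can check rhat properly (the numbers are just placeholders):

# Requires Julia to be started with multiple threads, e.g. JULIA_NUM_THREADS=4.
tchain = sample(gmm_model, gmm_sampler, MCMCThreads(), 1000, 4)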

Oh yeah, my fault. Sorry that was a stupid mistake.

No worries. I do something like that at least once a day. :laughing:


Thanks for the help @Christopher_Fisher and @trappmartin, it is working now, I think, to some extent; initialising at the correct values also helps the sampling in this case when the variance is small. I have one further question, which is really more of a hopeful one: do you have any tips for scaling this so that it might work with more clusters? Naturally I feel a larger number of samples / more data makes sense if I wanted to have, say, 25 clusters, but I wondered if there were any further tips you might have to improve the model. I presume the majority of the difficulty here arises from the multi-modality that grows with the number of components of the mixture, and so it might just be an inevitable uphill struggle.

As you add more clusters, ReverseDiff will have a larger and larger performance advantage. Unfortunately, there is a performance gap between ReverseDiff and Stan, but if I recall correctly, Stan also struggles with this particular type of model. Last time I inquired, there were a few efforts to develop new reverse-mode AD packages, but those details were lost on Slack.
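
For reference, switching the backend looks something like this; exact requirements depend on your Turing version (in some versions the compiled-tape cache also needs Memoization.jl loaded):

using ReverseDiff, Memoization

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)  # cache compiled tapes; fine here because the model has no data-dependent control flow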