I’m fitting a marginalized Gaussian mixture model with Turing, and the sampler is running very slowly. Looking at the output of `@code_warntype`, it appears that drawing the mixture weights from a Dirichlet distribution introduces a type instability: when I make the weight vector constant, the sampler runs much faster, so that seems to be the issue. Is this a bug? Any ideas for a workaround? Thanks!
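For context on what `@code_warntype` flags: a function is type-unstable when the compiler cannot infer one concrete type for a variable or return value from the argument types alone, and such values are shown as `Union`/`Any` in the output. The toy functions `unstable` and `stable` below are hypothetical and unrelated to Turing. They are only a minimal illustration of the kind of warning meant here, not a diagnosis of the model.

```julia
using InteractiveUtils  # provides @code_warntype outside the REPL

# The return type depends on the *value* of x, so inference can only give a Union.
unstable(x) = x > 0 ? 1 : 1.0

# Both branches return Float64, so inference gives a single concrete type.
stable(x) = x > 0 ? 1.0 : 0.0

# Inferred return types for an Int argument.
println(Base.return_types(unstable, (Int,)))
println(Base.return_types(stable, (Int,)))

# Prints the typed code; the Union-typed return is highlighted as a warning.
@code_warntype unstable(1)
```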
MWE:
```julia
using Turing

@model MarginalizedGMM(x, K, ::Type{T}=Vector{Float64}) where {T} = begin
    N = length(x)
    μ = T(undef, K)
    σ = T(undef, K)
    for i in 1:K
        μ[i] ~ Normal(0, 5)
        σ[i] ~ Gamma()
    end
    w ~ Dirichlet(K, 1.0)
    # w = T([0.75, 0.25])  # way faster with this line instead of the Dirichlet draw above
    for i in 1:N
        x[i] ~ Distributions.UnivariateGMM(μ, σ, Categorical(w))
    end
    return (μ::T, σ::T, w::T)
end

x = [randn(150) .- 2; randn(50) .+ 2]
gmm = MarginalizedGMM(x, 2)

varinfo = Turing.VarInfo(gmm)
spl = Turing.SampleFromPrior()
@code_warntype gmm.f(varinfo, spl, Turing.DefaultContext(), gmm)

chn = sample(gmm, NUTS(100, 0.65), 1000)
```
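For reference, each `x[i] ~ Distributions.UnivariateGMM(μ, σ, Categorical(w))` statement adds the marginal log-likelihood log Σₖ w[k]·Normal(x[i] | μ[k], σ[k]) to the log joint. The discrete component label is summed out, which is what makes the model usable with NUTS. The plain-Julia sketch below writes that term out by hand. The helper names (`normal_logpdf`, `logsumexp`, `mixture_logpdf`, `mixture_loglik`) are hypothetical and are not Turing or Distributions API; the sketch only shows what is being computed and is not a workaround for the slowdown.

```julia
# Log-density of Normal(μ, σ) at x (σ is the standard deviation, as in UnivariateGMM).
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Numerically stable log(sum(exp.(a))): shift by the maximum before exponentiating.
function logsumexp(a)
    m = maximum(a)
    return m + log(sum(exp.(a .- m)))
end

# Marginal log-likelihood of one observation under a K-component mixture:
# log p(x) = log Σ_k w[k] * Normal(x | μ[k], σ[k]), with the component label summed out.
function mixture_logpdf(x, μ, σ, w)
    return logsumexp([log(w[k]) + normal_logpdf(x, μ[k], σ[k]) for k in eachindex(w)])
end

# Total log-likelihood: observations are independent given (μ, σ, w).
mixture_loglik(xs, μ, σ, w) = sum(x -> mixture_logpdf(x, μ, σ, w), xs)
```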