Type instability with Dirichlet distribution in Turing

I’m fitting a marginalized Gaussian mixture model with Turing, and the sampler runs very slowly. Looking at the output of `@code_warntype`, it appears that drawing the mixture weights from a Dirichlet distribution introduces a type instability (when I replace the weight vector with a constant, the sampler runs much faster, so that seems to be the culprit). Is this a bug? Any ideas for a workaround? Thanks!
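For readers less familiar with the term: a type instability means the compiler cannot infer a concrete type for a value, so it falls back to dynamic dispatch. A minimal, Turing-independent sketch of the kind of function `@code_warntype` flags:

```julia
# Return type depends on a runtime value, so inference can only
# conclude Union{Float64, Int64} -- this is what @code_warntype
# highlights in red.
unstable(flag) = flag ? 1 : 1.0

# Always returns Float64; inference succeeds with a concrete type.
stable(flag) = flag ? 1.0 : 2.0

# @code_warntype unstable(true)   # Body::Union{Float64, Int64}
# @code_warntype stable(true)     # Body::Float64
```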

MWE:

```julia
using Turing

@model MarginalizedGMM(x, K, ::Type{T}=Vector{Float64}) where {T} = begin
    N = length(x)
    μ = T(undef, K)
    σ = T(undef, K)
    for i in 1:K
        μ[i] ~ Normal(0, 5)
        σ[i] ~ Gamma()
    end
    w ~ Dirichlet(K, 1.0)
    # w = T([0.75, 0.25])  # Way faster with this line instead of ↑
    for i in 1:N
        x[i] ~ Distributions.UnivariateGMM(μ, σ, Categorical(w))
    end
    return (μ::T, σ::T, w::T)
end
```


```julia
x = [randn(150) .- 2; randn(50) .+ 2]
gmm = MarginalizedGMM(x, 2)
varinfo = Turing.VarInfo(gmm)
spl = Turing.SampleFromPrior()
@code_warntype gmm.f(varinfo, spl, Turing.DefaultContext(), gmm)

chn = sample(gmm, NUTS(100, 0.65), 1000)
```
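One workaround sometimes suggested while an issue like this is open (a sketch only; the helper name `stickbreak` is my invention, not a Turing API): construct the weights deterministically from K − 1 draws in (0, 1) via stick-breaking, so the weight vector’s element type follows the element type of the draws (e.g. dual numbers under AD) rather than being pinned to `Vector{Float64}` by the Dirichlet:

```julia
# Stick-breaking transform: map K-1 values in (0, 1) to a length-K
# probability vector. The element type T propagates from the input.
function stickbreak(v::AbstractVector{T}) where {T<:Real}
    K = length(v) + 1
    w = Vector{T}(undef, K)
    remaining = one(T)
    for k in 1:(K - 1)
        w[k] = v[k] * remaining   # take a fraction of the remaining stick
        remaining -= w[k]
    end
    w[K] = remaining              # leftover mass goes to the last component
    return w
end

stickbreak([0.75])                # -> [0.75, 0.25], sums to 1
```

If I recall the stick-breaking representation correctly, drawing `v[k] ~ Beta(1, K - k)` reproduces the uniform `Dirichlet(K, 1.0)` prior on `w`; whether this actually removes the instability in the model above I haven’t verified.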

Please open an issue so it doesn’t get forgotten. I will take a look when I get some time.

And by the way, it’s natural for the model to be faster with the Dirichlet draw replaced by a constant: with `w ~ Dirichlet(K, 1.0)` you are also accumulating the Dirichlet’s logpdf and differentiating it with respect to the transformed parameters, which that constant line avoids. But the type instability is what concerns me.
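For concreteness, the extra term in question is the Dirichlet log-density evaluated at the current weights. For this particular prior it happens to be zero at interior points, since `Dirichlet(2, 1.0)` is uniform on the simplex, but the sampler still has to accumulate and differentiate it:

```julia
using Distributions

# Log-density of the mixture-weight prior at the constant used in the MWE.
# Dirichlet(2, 1.0) is the uniform distribution on the 2-simplex, so the
# log-density is 0 at any interior point -- but the AD pass still sees it.
logpdf(Dirichlet(2, 1.0), [0.75, 0.25])   # -> 0.0
```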

Issue opened: https://github.com/TuringLang/Turing.jl/issues/1276

And yeah, I wouldn’t be surprised to see it run somewhat slower when it also has to sample the weights, but with the type instability the slowdown is ~25x, which is what caught my attention.
