Recommended sampler for data with Categorical likelihood

The NUTS sampler takes a very long time on this MWE:

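```julia
using Distributions
using Turing

d = Categorical([0.7, 0.29, 0.01])
N = 100
Y = rand(d, N)

@model function test(Y)
    p ~ Dirichlet([2.0, 2.0, 2.0])   # prior on the 3-category probability vector
    for i in eachindex(Y)            # iterate over the data, not the global N
        Y[i] ~ Categorical(p)        # one likelihood term per observation
    end
end

m = test(Y)
chain = sample(m, NUTS(), 1_000)
```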
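For this particular MWE the posterior does not actually need MCMC: the Dirichlet prior is conjugate to the Categorical likelihood, so with prior parameters α and category counts n_k the posterior is Dirichlet(α + n). The sketch below uses only Distributions and the same data-generating setup; the seed and the names `alpha_prior`, `counts`, `posterior` are illustrative additions. It gives an exact reference to check any sampler's output against.

```julia
using Distributions
using Random

Random.seed!(42)                        # seed added only for reproducibility
d = Categorical([0.7, 0.29, 0.01])
N = 100
Y = rand(d, N)

alpha_prior = [2.0, 2.0, 2.0]           # same Dirichlet prior as in the model
K = length(alpha_prior)

# Category counts n_k = #{i : Y[i] == k}; for iid Categorical data
# these are sufficient statistics for p.
counts = [count(==(k), Y) for k in 1:K]

# Conjugate update: Dirichlet(alpha) prior + Categorical data => Dirichlet(alpha + counts)
posterior = Dirichlet(alpha_prior .+ counts)

post_mean = mean(posterior)             # exact posterior mean of p
draws = rand(posterior, 1_000)          # K x 1000 matrix, one draw of p per column
```

Comparing `post_mean` with the NUTS chain's means is a cheap sanity check before worrying about sampler choice.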

My real problem has thousands of observations, so I'll probably need a different sampler. I'm hoping someone can suggest one for this kind of problem; I've been spoiled by NUTS just working (for the most part) without my having to tune anything.
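One option that keeps NUTS: because the observations are iid, the likelihood depends on `Y` only through the category counts, so the N separate `Categorical` statements can be collapsed into a single `Multinomial` observation on the counts. The log-likelihood changes only by the multinomial coefficient, which does not depend on `p`, so the posterior for `p` is unchanged while each log-density evaluation costs O(K) instead of O(N). This is a sketch, not benchmarked, assuming the observations really are exchangeable; `test_counts`, `counts`, and the larger `N` are hypothetical choices for illustration.

```julia
using Distributions
using Turing
using Random

Random.seed!(42)
d = Categorical([0.7, 0.29, 0.01])
N = 1_000                               # larger N to mimic the real problem
Y = rand(d, N)

# Sufficient statistics: how many observations fall in each category.
K = length(probs(d))
counts = [count(==(k), Y) for k in 1:K]

# Hypothetical count-based version of `test`: one Multinomial term
# replaces the N Categorical terms (same posterior for p).
@model function test_counts(counts)
    K = length(counts)
    p ~ Dirichlet(fill(2.0, K))
    counts ~ Multinomial(sum(counts), p)
end

m_counts = test_counts(counts)
chain_counts = sample(m_counts, NUTS(), 1_000)
```

The model's cost no longer grows with the number of observations, which is usually what matters once N is in the thousands.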

To provide another data point: Dirichlet is still much, much slower than Beta on an otherwise equivalent two-category model:

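```julia
# Two-category version; assumes Y contains only the values 1 and 2,
# e.g. Y = rand(Categorical([0.7, 0.3]), N)
@model function test(Y)
    p ~ Dirichlet([2.0, 2.0])
    for i in eachindex(Y)
        Y[i] ~ Categorical(p)
    end
end

m = test(Y)
chain = sample(m, NUTS(), 1_000)

# Same likelihood with a scalar Beta prior on p = P(Y == 1)
@model function test2(Y)
    p ~ Beta(1, 1)
    for i in eachindex(Y)
        Y[i] ~ Categorical([p, 1 - p])
    end
end

m2 = test2(Y)
chain = sample(m2, NUTS(), 1_000)

# m (Dirichlet) takes 70.5s (10.7 ESS/s)
# m2 (Beta) takes 0.4s (1060 ESS/s)
```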

No idea why.
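One detail worth checking in this comparison is the prior. For two categories, the first component of a `Dirichlet([a, b])` draw is `Beta(a, b)`-distributed, so the Dirichlet model matches `Beta(2, 2)` rather than the `Beta(1, 1)` used in `test2`. With 100 observations that prior difference is unlikely to explain a gap of two orders of magnitude in ESS/s, but it does mean the two runs target slightly different posteriors. A quick numerical check of the density identity with Distributions (the grid `xs` is an arbitrary choice):

```julia
using Distributions

a, b = 2.0, 2.0
dir_dist = Dirichlet([a, b])
beta_dist = Beta(a, b)

# The Dirichlet density on the 2-simplex, written as a function of the
# first coordinate x, coincides with the Beta(a, b) density at x.
xs = range(0.05, 0.95; length = 19)
dir_logpdf = [logpdf(dir_dist, [x, 1 - x]) for x in xs]
beta_logpdf = [logpdf(beta_dist, x) for x in xs]

max_diff = maximum(abs.(dir_logpdf .- beta_logpdf))   # ≈ 0 up to floating-point error
```

Matching the priors (`Beta(2, 2)` in `test2`) makes the timing comparison isolate the parameterization of `p` rather than the prior.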
