I have a simple Turing model I’ve been trying to implement, but I’ve been running into problems actually running it. I say simple, but my actual problem is not “small”, parameter-wise: I have about 18 groups and about 60 items, and each item has 30 response options (which I’m treating as categorical). So, a lot of parameters. NUTS won’t run (not surprising); MH will, but crashes with a stack overflow after sampling finishes.
For testing, I cut the problem down to 18 groups, 1 item, and 30 response options. NUTS will run, but it’s incredibly slow. MH samples very quickly, but then takes forever to do whatever happens after sampling (I guess saving and processing the chains?).
Here’s an MWE. If anyone has pointers on what I’m doing wrong, please let me know!
using Turing, MCMCChains
# Generate some fake data; it's an attempt at a hierarchical multinomial model with 18 groups and 30 response categories
fake_dat = [5*ones(Int64, 30) for j=1:18] #Frequency of responses in each of 30 categories for 18 groups
fake_N = fill(5*30, 18) #Total N
# Model an overall distribution of responses (beta); alpha governs how closely the individual group distributions come to beta
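@model multinomial_RE(data, N, ::Type{T}=Vector{Float64}) where {T} = begin
    B = length(data)      # number of groups
    R = length(data[1])   # number of response categories
    theta = [T(undef, R) for j=1:B]   # per-group category probabilities
    beta ~ Dirichlet(R, 1)            # overall response distribution
    alpha ~ truncated(Normal(1, 100), 0, Inf)   # concentration of the groups around beta
    for b in 1:B
        theta[b] ~ Dirichlet(beta * alpha)
        data[b] ~ Multinomial(N[b], theta[b])
    end
end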
model1 = sample(multinomial_RE(fake_dat, fake_N), NUTS(0.65), 1000)
model2 = sample(multinomial_RE(fake_dat, fake_N), MH(), 50000)
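In math notation, the model above is roughly the following. This is only a sketch of what the code is meant to express: $y_b$ is just a label for `data[b]`, $N_b$ is `N[b]`, $R = 30$ is the number of response categories, and the 100 in the Normal is a standard deviation, as in Distributions.jl.

```latex
\begin{aligned}
\beta &\sim \mathrm{Dirichlet}(\mathbf{1}_R) \\
\alpha &\sim \mathrm{Normal}(1,\, 100) \ \text{truncated to } [0, \infty) \\
\theta_b \mid \alpha, \beta &\sim \mathrm{Dirichlet}(\alpha \beta), \qquad b = 1, \dots, 18 \\
y_b \mid \theta_b &\sim \mathrm{Multinomial}(N_b,\, \theta_b)
\end{aligned}
```

So even the cut-down version has 18 × 30 group-level probabilities in theta, plus the 30 in beta and the single alpha.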
So, the NUTS sampler estimates hours to finish. The Metropolis sampler takes seconds for the 50k iterations, but then just hangs at 100% on the progress bar. I also sometimes get what looks like numerical instability: random, infrequent failures where the parameters are listed as NaN in the output. Checking with @code_warntype doesn’t show any issues that I can see.
Thoughts? I get that my original problem is huge, so I expected problems, but the MWE above should be small enough to work in a decent time, right?
EDIT: Forgot to mention, I am on Julia 1.3.0, Windows.