I want to start using Turing.jl and decided to learn it by converting an example from Jags.jl (which I have used before). I have selected the ‘rats’ example from Jags.jl for the test. I have working code, but two issues that I would like some help with. Here is my version in Turing.jl, excluding the data definition:
```julia
using StatsPlots, Turing

@model function rat_sampler(x, Y)
    # input data properties
    xBar = mean(x)
    N = size(Y, 1)
    # sample priors
    alphaC ~ Normal(0, 100)
    tauAlpha ~ Gamma(0.001, 1000)
    betaC ~ Normal(0, 100)
    tauBeta ~ Gamma(0.001, 1000)
    tauC ~ Gamma(0.001, 1000)
    # computed values - should be stored
    sigma = 1.0 / sqrt(tauC)
    alpha0 = alphaC - betaC * xBar
    # sampling alpha and beta outside of the loop, since inside we would
    # have to access alpha[i] and beta[i], which would require defining
    # alpha and beta as an empty vector here anyway...
    alpha ~ filldist(Normal(alphaC, sqrt(1 / tauAlpha)), N)
    beta ~ filldist(Normal(betaC, sqrt(1 / tauBeta)), N)
    for i in 1:N
        for j in 1:size(Y, 2)
            mu = alpha[i] + beta[i] * (x[j] - xBar)
            Y[i, j] ~ Normal(mu, sigma)
        end
    end
end

# declare the model instance on all processes
model = rat_sampler(rats["x"], rats["Y"])

# sample 4 chains
chains = sample(model, NUTS(), MCMCThreads(), 10000, 4)

# plot and save results
p = plot(chains[[:tauC, :betaC]])  # plot only selected variables
```
This works, but I would like to include `sigma` and `alpha0` in the chains, so that I can see their distributions in `plot(chains)`, as is the case in the Jags.jl example. Is there some way to instruct Turing to add them to the chain?
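In case it clarifies what I am after: the closest workaround I am aware of is a sketch along these lines (it assumes the model is changed to end with `return (sigma = sigma, alpha0 = alpha0)`, and that Turing's `generated_quantities` then recovers these values per posterior draw - I have not verified this):

```julia
# Sketch only: assumes rat_sampler ends with
#     return (sigma = sigma, alpha0 = alpha0)
# so that each posterior draw carries the derived values.
gq = generated_quantities(model, chains)   # NamedTuples, one per (sample, chain)
sigma_draws = [q.sigma for q in gq]        # collect sigma for plotting etc.
```

But this keeps the derived values outside the `Chains` object, so they still would not show up in `plot(chains)`.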
On the other hand, I do not need the `alpha`s and `beta`s - is there some equivalent to Jags’ “monitor” functionality?
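For now I work around it by subsetting after sampling, as in the plotting line above (this only hides `alpha` and `beta` from the plots; they are still sampled and stored):

```julia
# Workaround sketch: keep only the scalar parameters for inspection.
# MCMCChains supports indexing a Chains object with a vector of symbols,
# as already used for plotting above.
monitored = chains[[:alphaC, :betaC, :tauAlpha, :tauBeta, :tauC]]
plot(monitored)
```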
The other concern is speed: the Jags.jl version takes under 2 minutes, while the presented Turing version takes ca. 7 minutes. I have also made a version using distributed sampling (as described in the Turing documentation), but that saves only ca. 1 minute, so it is still about 3 times slower than Jags.jl.
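One thing I have not benchmarked yet, in case it is relevant to the timing: I have read that reducing the number of `~` statements can help Turing, e.g. replacing the scalar observation loop with one multivariate normal per rat. A hypothetical variant (`MvNormal` with a scaled identity covariance is from Distributions.jl):

```julia
using Turing
using LinearAlgebra: I

# Hypothetical variant (not benchmarked): one multivariate observation
# per rat instead of one scalar observation per measurement.
@model function rat_sampler_mv(x, Y)
    xBar = mean(x)
    N = size(Y, 1)
    alphaC ~ Normal(0, 100)
    tauAlpha ~ Gamma(0.001, 1000)
    betaC ~ Normal(0, 100)
    tauBeta ~ Gamma(0.001, 1000)
    tauC ~ Gamma(0.001, 1000)
    sigma = 1.0 / sqrt(tauC)
    alpha ~ filldist(Normal(alphaC, sqrt(1 / tauAlpha)), N)
    beta ~ filldist(Normal(betaC, sqrt(1 / tauBeta)), N)
    for i in 1:N
        mu = alpha[i] .+ beta[i] .* (x .- xBar)  # mean growth curve for rat i
        Y[i, :] ~ MvNormal(mu, sigma^2 * I)      # iid noise with variance sigma^2
    end
end
```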
I tried other samplers, but NUTS is the only one with which I managed to produce reasonable answers.
PS: is there a Turing.jl equivalent to Jags’ “thin” argument, for using only every n-th sample?
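This is the kind of call I have in mind (assuming `sample` accepts a `thinning` keyword from the AbstractMCMC interface that Turing builds on - I have not verified this with Turing):

```julia
# Hypothetical: thin during sampling, keeping every 10th sample,
# via AbstractMCMC's `thinning` keyword (not verified with this model).
chains_thin = sample(model, NUTS(), MCMCThreads(), 10000, 4; thinning = 10)

# Alternative: thin an existing chain by indexing its iterations.
chains_thin2 = chains[1:10:end, :, :]
```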