Help converting Jags.jl example to Turing

I want to start using Turing and decided to learn it by converting an example from Jags.jl (which I have used before). I have selected the ‘rats’ example from Jags.jl for the test.
I have working code, but two issues that I would like some help with…

Here is my version in Turing.jl, excluding the data definition:

using StatsPlots, Turing

@model function rat_sampler(x, Y)
	# input data properties
	xBar = mean(x)
	N = size(Y, 1)

	# sample priors
	alphaC ~ Normal(0, 100)
	tauAlpha ~ Gamma(0.001, 1000)
	betaC  ~ Normal(0, 100)
	tauBeta ~ Gamma(0.001, 1000)
	tauC ~ Gamma(0.001, 1000)
	# computed values - should be stored
	sigma = 1.0/sqrt(tauC)
	alpha0 = alphaC - betaC * xBar

	# sampling alpha and beta outside of the loop, since inside we would
	# have to access alpha[i] and beta[i], which would require defining
	# alpha and beta as empty vectors first anyway (see the sketch after the listing)
	alpha ~ filldist(Normal(alphaC, sqrt(1/tauAlpha)), N)
	beta ~ filldist(Normal(betaC, sqrt(1/tauBeta)), N)

	for i in 1:N
		for j in 1:size(Y, 2)
			mu = alpha[i] + beta[i] * (x[j] - xBar)
			Y[i,j] ~ Normal(mu, sigma)
		end
	end
end

# instantiate the model with the data
model = rat_sampler(rats["x"], rats["Y"])

# sample 4 chains
chains = sample(model, NUTS(), MCMCThreads(), 10000, 4)

# plot and save results
p = plot(chains[[:tauC, :betaC]])  # plot only selected variables
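
For reference, the in-loop alternative mentioned in the model comments would look roughly like the snippet below. This is only an untested sketch: note that hard-coding Float64 as the element type can break automatic differentiation (e.g. with ForwardDiff), so the type may need to be made generic.

	# hypothetical in-loop version of the alpha/beta priors
	alpha = Vector{Float64}(undef, N)  # element type may need to be generic for AD
	beta  = Vector{Float64}(undef, N)
	for i in 1:N
		alpha[i] ~ Normal(alphaC, sqrt(1/tauAlpha))
		beta[i]  ~ Normal(betaC, sqrt(1/tauBeta))
	end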

This works, but I would like to include sigma and alpha0 in the chains, so that I can see their distributions in plot(chains) - as is the case in the Jags.jl example. Is there some way to instruct Turing to add them to the chain?
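
The closest I have found so far is returning the derived values from the model and extracting them afterwards with generated_quantities - a sketch, assuming a reasonably recent Turing version (in newer releases the same function is, I believe, called returned) - but that produces a separate array rather than adding variables to the chain:

@model function rat_sampler(x, Y)
	# ... same body as above, plus a return statement ...
	return (sigma = sigma, alpha0 = alpha0)
end

gq = generated_quantities(model, chains)  # NamedTuples of (sigma, alpha0), one per sample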
On the other hand, I do not need alphas and betas - is there some equivalent to Jags’ “monitor” functionality?

The other concern is speed: the Jags.jl version takes under 2 minutes, while the Turing version above takes ca. 7 minutes. I have also made a version using distributed sampling (as described in the Turing documentation, and sketched below), but that saves only ca. 1 minute, so it is still roughly 3 times slower than Jags.jl.
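
For completeness, the distributed variant I tried follows the pattern from the Turing documentation, roughly like this (a sketch; the worker count is chosen arbitrarily):

using Distributed
addprocs(4)
@everywhere using Turing
# the data and the model definition must also be available on all workers
chains = sample(model, NUTS(), MCMCDistributed(), 10000, 4)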

I tried other samplers, but NUTS is the only one with which I managed to produce reasonable answers.

PS: is there a Turing.jl equivalent to Jags’ “thin” argument, for using only every n-th sample?

You can use MCMCChains.jl to plot and monitor chains. Your timing comparison is not fair: the Gibbs sampler in Jags will be faster per iteration than the NUTS sampler in Turing, but the mixing will be much better with NUTS. Hence, you should first compare effective sample sizes, which will tell you how many iterations you actually need with NUTS.
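
For example, something along these lines should work (a sketch; the exact output format depends on the MCMCChains version):

using MCMCChains
ess(chains)        # per-parameter effective sample size
describe(chains)   # the default summary also includes an ess column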
