Hi everybody,
I’ve coded up a multilevel model with random intercepts and slopes in Turing (see below). However, my model is very slow compared to similar models in Stan. Does anybody have any insights on how to speed things up further?
Thanks
```julia
using Turing
using Random
using MCMCChains
using LinearAlgebra: I
using StatsPlots

@model function linreg(X, y, idx; P = size(X, 2), G = length(unique(idx)))
    # prior for overall noise
    σ ~ Exponential(1)
    # priors over population-level means of the intercept and slopes
    α ~ Normal(0, 2.5)
    β ~ filldist(Normal(0, 2.5), P)
    # standardized subject-level offsets for intercepts and slopes (non-centered parameterization)
    αₚ ~ filldist(Normal(), G)
    βₚ ~ filldist(Normal(), P, G)
    # priors over stds of subject-level intercepts and slopes
    σₐ ~ Exponential(1)
    σᵦ ~ filldist(Exponential(1), P)
    # construct subject-level regression coefficients
    α_ = α .+ αₚ .* σₐ
    β_ = β .+ βₚ .* σᵦ
    # likelihood
    μ = α_[idx] + sum(X .* β_[:, idx]', dims=2)
    y ~ MvNormal(vec(μ), σ^2 * I)
end;
```
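For reference, here is the model written out in notation. This is my reading of the code above, not an official specification: g[i] stands for `idx[i]` (the group of observation i), z^α_g corresponds to `αₚ[g]`, z^β_pg to `βₚ[p, g]`, α_g to `α_[g]` and β_pg to `β_[p, g]`. `Normal(0, 2.5)` in Distributions.jl takes a standard deviation, hence the variance 2.5².

$$
\begin{aligned}
\sigma,\ \sigma_\alpha,\ \sigma_{\beta,p} &\sim \operatorname{Exponential}(1), \qquad p = 1,\dots,P\\
\alpha,\ \beta_p &\sim \mathcal{N}(0,\ 2.5^2)\\
z^{\alpha}_g,\ z^{\beta}_{pg} &\sim \mathcal{N}(0,\ 1), \qquad g = 1,\dots,G\\
\alpha_g &= \alpha + \sigma_\alpha\, z^{\alpha}_g\\
\beta_{pg} &= \beta_p + \sigma_{\beta,p}\, z^{\beta}_{pg}\\
y_i &\sim \mathcal{N}\Big(\alpha_{g[i]} + \textstyle\sum_{p=1}^{P} X_{ip}\,\beta_{p,g[i]},\ \sigma^2\Big)
\end{aligned}
$$

Because the subject-level coefficients are built from standard-normal offsets that are only rescaled and shifted afterwards, this is the non-centered parameterization of the hierarchy.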
```julia
# generate some synthetic data to test the model
Random.seed!(123)
N = 100
G = 10
α = 5
β = [-2, 3, -4, 5]
σ = .1
σₐ = 2
σᵦ = [1 2 .1 .5]
αₚ = rand(Normal(), G) .* σₐ .+ α
βₚ = rand(MvNormal(zeros(length(β)), I), G) .* σᵦ' .+ β
X = randn(N * G, length(β))
idx = collect(repeat(1:G, inner=N))
μ = αₚ[idx] + sum(X .* βₚ[:, idx]', dims=2)
y = rand(MvNormal(vec(μ), σ^2 * I))
```
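The one-line linear predictor used in both the model and the data generation, `α_[idx] + sum(X .* β_[:, idx]', dims=2)`, is compact but not obvious. The sketch below spells it out as an explicit loop and checks that the two forms agree on small random inputs. `linpred_vectorized`, `linpred_loop` and `compare_linpred` are hypothetical helper names, and the dimensions are arbitrary test values, not the post's data. Note that the vectorized form allocates a P×(N·G) copy for `β_[:, idx]` plus an (N·G)×P array for the elementwise product on every evaluation; I have not measured whether either form is faster under Turing's AD backends.

```julia
using Random

# Vectorized form, as written in the model and in the data-generation step.
linpred_vectorized(X, idx, a, B) = vec(a[idx] .+ sum(X .* B[:, idx]', dims=2))

# Same quantity as an explicit loop: observation i uses the intercept a[g]
# and the slope column B[:, g] of its own group g = idx[i].
function linpred_loop(X, idx, a, B)
    T = promote_type(eltype(X), eltype(a), eltype(B))
    μ = zeros(T, size(X, 1))
    for i in axes(X, 1)
        g = idx[i]
        μ[i] = a[g] + sum(X[i, p] * B[p, g] for p in axes(X, 2))
    end
    return μ
end

# Small random check on hypothetical dimensions (not the post's data).
function compare_linpred(; N=5, G=3, P=4, rng=MersenneTwister(1))
    X = randn(rng, N * G, P)
    idx = repeat(1:G, inner=N)
    a = randn(rng, G)
    B = randn(rng, P, G)
    return linpred_vectorized(X, idx, a, B), linpred_loop(X, idx, a, B)
end

μ_vec, μ_loop = compare_linpred()
println(μ_vec ≈ μ_loop)   # prints true
```

With the post's own variables in scope, `linpred_loop(X, idx, αₚ, βₚ) ≈ vec(μ)` should also hold, since that is exactly how `μ` is built in the data-generation step.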
```julia
model = linreg(X, y, idx)
chn = sample(model, NUTS(), 100)
plot(group(chn, :β))
plot(group(chn, :σᵦ))
```
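As a quick, non-Bayesian sanity check on the synthetic data, separate least-squares fits per group should recover each group's intercept and slopes closely, because the simulated noise (σ = 0.1) is small relative to the coefficients. The sketch below is an assumption-laden stand-in, not part of the original workflow: `simulate` and `ols_by_group` are hypothetical helpers, the simulation mirrors the post's setup but uses only the standard library (`randn` instead of Distributions.jl), so its draws differ from the post's, and `ols_by_group` assumes group labels run from 1 to G.

```julia
using LinearAlgebra, Random

# Simulate data with the same structure as in the post (group-specific intercepts
# and slopes), using only the standard library: randn in place of Distributions.jl.
function simulate(; N=100, G=10, α=5.0, β=[-2.0, 3.0, -4.0, 5.0], σ=0.1,
                  σₐ=2.0, σᵦ=[1.0, 2.0, 0.1, 0.5], rng=MersenneTwister(123))
    P = length(β)
    αₚ = α .+ σₐ .* randn(rng, G)        # one intercept per group
    βₚ = β .+ σᵦ .* randn(rng, P, G)     # one slope column per group
    X = randn(rng, N * G, P)
    idx = repeat(1:G, inner=N)
    μ = [αₚ[idx[i]] + dot(view(X, i, :), view(βₚ, :, idx[i])) for i in eachindex(idx)]
    y = μ .+ σ .* randn(rng, N * G)
    return (; X, y, idx, αₚ, βₚ)
end

# Ordinary least squares within each group: regress y on [1 X] using only that
# group's rows. Assumes group labels are 1:G.
function ols_by_group(X, y, idx)
    G = maximum(idx)
    coefs = zeros(size(X, 2) + 1, G)    # row 1: intercept, rows 2:end: slopes
    for g in 1:G
        rows = findall(==(g), idx)
        Xg = hcat(ones(length(rows)), X[rows, :])
        coefs[:, g] = Xg \ y[rows]       # least-squares solve (QR for tall Xg)
    end
    return coefs
end

sim = simulate()
coefs = ols_by_group(sim.X, sim.y, sim.idx)
truth = vcat(sim.αₚ', sim.βₚ)           # same layout as coefs
max_err = maximum(abs.(coefs .- truth))
println("largest |OLS estimate - true value| = ", max_err)
```

With the post's own data in scope, `ols_by_group(X, y, idx)` can be compared against `vcat(αₚ', βₚ)` in the same way.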