How to get better performance from NUTS in Turing?

I’m wondering if anyone has further performance tips to speed this up?

This MWE is an attempt to use NUTS on data similar in size to my actual use case, following the model introduced in this post, but without assuming that the multivariate parameters are all the same size.

The issue is that even with reverse-mode AD and tape caching, and with every group-level variable written out explicitly to avoid loops over the hierarchy, the progress meter estimates about an hour for 200 iterations on the example below.
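For reference, written out, the hierarchy the MWE implements is, per group $k$ with $m_k$ coefficients (shape–scale Gamma as in Distributions.jl; note that $\sigma$ enters the likelihood as a covariance scale, matching `Matrix(σ*I, d, d)` in the code):

$$
\tau_k^2 \sim \operatorname{Gamma}\!\left(\tfrac{m_k+1}{2},\ \tfrac{2}{\lambda^2}\right), \qquad
\beta_k \mid \tau_k^2 \sim \mathcal{N}\!\left(\mathbf{0},\ \sigma^2 \tau_k^2 I_{m_k}\right), \qquad
y_i \mid \beta \sim \mathcal{N}\!\left((x_i^\top \beta)\,\mathbf{1}_d,\ \sigma I_d\right)
$$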

using Pkg; Pkg.activate("./")
using LinearAlgebra
using Distributions, Turing
using ReverseDiff, Memoization

# Turing.setadbackend(:zygote) # error
# Turing.setadbackend(:tracker) # error
Turing.setadbackend(:reversediff)

# tape caching (https://turing.ml/dev/docs/using-turing/autodiff#switching-ad-modes);
# safe here because the model has no parameter-dependent control flow
Turing.setrdcache(true)

@model function grouped_lasso_multi(y, X, σ, λ², mk, ::Type{T} = Float64) where {T}
	# number of observations and features
	p, nobs = size(X) # p features, nobs observations; columns of X are blocked like [lig^{(T=1,P)}', path^{(T=1,P)}', ..., lig^{(T=nT,P)}', path^{(T=nT,P)}']

	# set variance priors (group-wise shrinkage of the linear coefficients)

	τ²₁ ~ Gamma((mk[1] + 1) / 2, 2 / λ²)
	τ²₂ ~ Gamma((mk[2] + 1) / 2, 2 / λ²)
	τ²₃ ~ Gamma((mk[3] + 1) / 2, 2 / λ²)
	τ²₄ ~ Gamma((mk[4] + 1) / 2, 2 / λ²)
	τ²₅ ~ Gamma((mk[5] + 1) / 2, 2 / λ²)
	τ²₆ ~ Gamma((mk[6] + 1) / 2, 2 / λ²)

	# zero-mean isotropic priors: MvNormal(d, s) is N(0, s²·I) of dimension d
	β₁ ~ MvNormal(mk[1], σ * sqrt(τ²₁))
	β₂ ~ MvNormal(mk[2], σ * sqrt(τ²₂))
	β₃ ~ MvNormal(mk[3], σ * sqrt(τ²₃))
	β₄ ~ MvNormal(mk[4], σ * sqrt(τ²₄))
	β₅ ~ MvNormal(mk[5], σ * sqrt(τ²₅))
	β₆ ~ MvNormal(mk[6], σ * sqrt(τ²₆))

	# set the target distribution; stack the coefficients and build the
	# covariance once, outside the loop, instead of once per observation
	β = vcat(β₁, β₂, β₃, β₄, β₅, β₆)
	d = size(y, 1)
	Σ = Matrix(σ * I, d, d)
	for i in 1:nobs
		mu = dot(view(X, :, i), β)
		y[:, i] ~ MvNormal(fill(mu, d), Σ)
	end
end
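One alternative I have sketched, but not benchmarked, replaces the per-observation loop with a single observe statement via `arraydist` (re-exported by Turing from DistributionsAD); the prior block is unchanged, and I have not verified how this interacts with ReverseDiff's tape:

@model function grouped_lasso_multi_vec(y, X, σ, λ², mk)
	τ²₁ ~ Gamma((mk[1] + 1) / 2, 2 / λ²)
	τ²₂ ~ Gamma((mk[2] + 1) / 2, 2 / λ²)
	τ²₃ ~ Gamma((mk[3] + 1) / 2, 2 / λ²)
	τ²₄ ~ Gamma((mk[4] + 1) / 2, 2 / λ²)
	τ²₅ ~ Gamma((mk[5] + 1) / 2, 2 / λ²)
	τ²₆ ~ Gamma((mk[6] + 1) / 2, 2 / λ²)

	β₁ ~ MvNormal(mk[1], σ * sqrt(τ²₁))
	β₂ ~ MvNormal(mk[2], σ * sqrt(τ²₂))
	β₃ ~ MvNormal(mk[3], σ * sqrt(τ²₃))
	β₄ ~ MvNormal(mk[4], σ * sqrt(τ²₄))
	β₅ ~ MvNormal(mk[5], σ * sqrt(τ²₅))
	β₆ ~ MvNormal(mk[6], σ * sqrt(τ²₆))

	# stack the group coefficients once and compute all per-observation
	# means in a single matrix product
	β = vcat(β₁, β₂, β₃, β₄, β₅, β₆)
	mu = X' * β
	d = size(y, 1)
	M = ones(d) * mu'   # d × nobs; column i is filled with mu[i]
	# Normal takes a standard deviation, so sqrt(σ) matches the covariance
	# Matrix(σ*I, d, d) in the looped version above
	y ~ arraydist(Normal.(M, sqrt(σ)))
end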


ntrain = 7000
yTrain  = rand(MvNormal(zeros(3), 1.0), ntrain)
x0Train = rand(MvNormal(zeros(3), 1.0), ntrain)
x1Train = rand(MvNormal(zeros(5), 1.0), ntrain)
x2Train = rand(MvNormal(zeros(10), 1.0), ntrain)
x3Train = rand(MvNormal(zeros(20), 1.0), ntrain)
x4Train = rand(MvNormal(zeros(15), 1.0), ntrain)
x5Train = rand(MvNormal(zeros(11), 1.0), ntrain)

xTrain = vcat(x0Train, x1Train, x2Train, x3Train, x4Train, x5Train)

σ = 1.0
λ² = 2
mk = [3,5,10,20,15,11]

model = grouped_lasso_multi(yTrain, xTrain, σ, λ², mk)
chn = sample(model, NUTS(), 200)
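In case it is relevant to any answers: I know I can run several chains in parallel, as in the sketch below, but that only spreads chains across threads and does not reduce the per-iteration cost I am asking about.

# four chains of 200 iterations across threads (start Julia with, e.g., `julia -t 4`);
# parallelism is across chains only, so each iteration is just as slow
chn = sample(model, NUTS(), MCMCThreads(), 200, 4)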