Slow Turing.jl sampling compared to python pymc

I should have started with the Turing Performance Tips page in the docs.
I now use NUTS(adtype=AutoReverseDiff(true)) for the sampling and annotate the model argument types precisely (the latter alone, as suggested by @garrett, was not enough to get the sampling started):

using Turing
using LinearAlgebra          # for LowerTriangular
import ReverseDiff           # loads the backend for AutoReverseDiff

@model function simplemodel(sst::Vector{Float64}, mean_::Vector{Float64}, L::LowerTriangular{Float64, Matrix{Float64}}, idx::Vector{Int64})
    n = length(mean_)
    random_scales ~ filldist(Normal(0, 1), n)
    full_field = mean_ .+ L * random_scales
    sst_field = full_field[idx]
    sst ~ MvNormal(sst_field, sst_err)  # sst_err is an observation-noise scale captured from global scope
end

chain = sample(simplemodel(sst, mean_, L, idx), NUTS(adtype=AutoReverseDiff(true)), 1000)
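For reference, here is a sketch of synthetic inputs that would make the snippet above self-contained; the problem sizes, values, and the sst_err noise scale are all made up for illustration (sst_err is the global the model captures):

```julia
using LinearAlgebra

n_full, n_obs = 20, 8                      # made-up problem sizes
mean_ = randn(n_full)                      # mean of the full field
# Stand-in lower-triangular factor (e.g. a Cholesky factor of a covariance)
L = LowerTriangular(tril(0.1 .* randn(n_full, n_full)) .+ Matrix(1.0I, n_full, n_full))
idx = collect(1:n_obs)                     # observed positions within the full field
sst_err = 0.1                              # observation-noise scale used inside the model
# Synthetic observations: one draw of the full field at idx, plus noise
sst = (mean_ .+ L * randn(n_full))[idx] .+ sst_err .* randn(n_obs)
```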

finally kick-starts the sampling. The toy example takes 4 min. I'll try it on my real problem; if that holds, it would be faster than Python.

EDIT: that works. 3 min 17 s for one chain. Perfect. NOTE: a direct comparison with pymc is beside the point (for that I'd need a proper linkage to the BLAS library, or other, faster samplers like JAX-based ones).

EDIT2: surprisingly, sampling the prior is now causing me trouble:
prior_chain = sample(simplemodel(sst, ...), Prior(), 1000, progress=true) takes a fraction of a second to complete the sampling, but a full 6 min to return a result.

EDIT3: for the prior-sampling issue, I ended up adding return full_field to my model, and I now do priormodel = simplemodel(missing, ...); prior_sst_samples = hcat([priormodel() for _ in 1:1000]...) instead of sampling via sample(...). That takes 3 seconds instead of 6 min with sample.
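A self-contained sketch of that workaround. Two assumptions not in the original post: the sst argument is loosened to Union{Missing, Vector{Float64}} (a strict Vector{Float64} annotation would not dispatch on missing), and sst_err is passed as an argument rather than read from a global; the covariance is written as sst_err^2 * I, the non-deprecated MvNormal form. All input sizes and values are made up.

```julia
using Turing, LinearAlgebra

# Model variant for prior-predictive draws: sst may be `missing`,
# and the latent full field is returned so each model call yields it.
@model function simplemodel(sst::Union{Missing, Vector{Float64}}, mean_::Vector{Float64},
                            L::LowerTriangular{Float64, Matrix{Float64}}, idx::Vector{Int64},
                            sst_err::Float64)
    n = length(mean_)
    random_scales ~ filldist(Normal(0, 1), n)
    full_field = mean_ .+ L * random_scales
    sst ~ MvNormal(full_field[idx], sst_err^2 * I)
    return full_field          # returned so priormodel() yields the latent field
end

# Made-up inputs, just to exercise the code
n = 10
mean_ = zeros(n)
L = LowerTriangular(Matrix(1.0I, n, n))
idx = [1, 3, 5]
sst_err = 0.1

# Calling a model instance samples every ~ statement from its prior
# and returns the model's return value; stack 1000 draws column-wise.
priormodel = simplemodel(missing, mean_, L, idx, sst_err)
prior_sst_samples = hcat([priormodel() for _ in 1:1000]...)   # n × 1000 matrix
```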

EDIT4: the Prior() sampling slowness was mostly fixed in Turing 0.39.3 (Prior() sampling takes longer than NUTS() · Issue #2604 · TuringLang/Turing.jl).