Slow Turing.jl sampling compared to Python PyMC

I have a spatial model on a (currently) 5x5 grid, for two variables (total size 3737), and 483 observations of one of the variables. In this simple setting I want to work in the full space (no dimensionality reduction) and model my system as a multivariate normal distribution (I then want to validate the output against an ensemble Kalman filter before making the model more complex). Here is a minimal, working example of my Turing.jl model:

using Turing
using Distributions
using LinearAlgebra

n = 3737
N = 50
nobs = 483

# synthetic ensemble used to build a mean and covariance for the field
ensemble = reshape(rand(filldist(Normal(0, 1), n*N)), n, N)
mean_ = mean(ensemble, dims=2)[:, 1]
cov_ = cov(ensemble, dims=2)
# Cholesky factor of the covariance, with a small jitter for numerical stability
L = cholesky(cov_ + Diagonal(1e-6 * ones(n))).L
# synthetic observations and their uncertainties
sst = randn(nobs)
sst_err = abs.(randn(nobs)*0.1) .+ 0.1
idx = collect(1:nobs)

@model function simplemodel(sst)
    # n, mean_, L, idx, and sst_err are the globals defined above
    random_scales ~ filldist(Normal(0, 1), n)
    full_field = mean_ .+ L * random_scales
    sst_field = full_field[idx]
    sst ~ MvNormal(sst_field, sst_err)
end

chain = sample(simplemodel(sst), NUTS(), 1000)

Unfortunately, it seems to take forever without the sampling even starting. On the other hand, an equivalent PyMC implementation takes about 9 min to complete (with real data), and about 20 min when called from Julia via PyCall. Here is the PyMC model called via PyCall:

using PyCall

function run_pymc_model(mean, chol, idx, sst_err, observed_sst)
    py"""
    import pymc as pm
    import numpy as np

    def build_and_run_model(mean, chol, idx, sst_err, observed_sst):
        # Define the PyMC model
        with pm.Model() as model:
            random_scales = pm.Normal('random_scales', mu=0, sigma=1, shape=mean.shape[0])
            full_field = mean + chol @ random_scales
            sst_observable = full_field[idx]
            pm.Normal('sst', mu=sst_observable, sigma=sst_err, observed=observed_sst)
            trace = pm.sample()
        return trace
    """

    build_and_run_model = pyimport("__main__").build_and_run_model

    trace = build_and_run_model(
        mean,
        chol,
        idx .- 1,  # Adjust for Python's 0-based indexing
        sst_err,
        observed_sst)

    return trace
end

trace = run_pymc_model(mean_, L, idx, sst_err, sst)

Calling Python from Julia for performance seems to defeat the purpose of using Julia in the first place. Hence my question: is there a better way of implementing my Julia model so that it matches the PyMC performance?

Hi @mahe,

If you are simply interested in geospatial Gaussian processes, take a look at GeoStats.jl. It has efficient simulation methods for grids with millions of cells.

Thanks @juliohm, I’ll take a look. I’m still interested in solving this, though.

My first idea would be to pass the global variables (mean_, L, …) as arguments to simplemodel. That should provide some speedup. There might be other optimizations you could do with Turing, but that is the first thing I would try.

Thanks @garrett, I’m now trying:

@model function simplemodel(sst::Vector{Float64}, mean_::Vector{Float64}, L::LowerTriangular{Float64, Matrix{Float64}}, idx::Vector{Int64})
    n = length(mean_)
    random_scales ~ filldist(Normal(0, 1), n)
    full_field = mean_ .+ L * random_scales
    sst_field = full_field[idx]
    sst ~ MvNormal(sst_field, sst_err)  # note: sst_err is still captured as a global here
end

chain = sample(simplemodel(sst, mean_, L, idx), NUTS(), 1000)

let’s see…

I should have started with the Turing Performance Tips.
I am now using NUTS(adtype=AutoReverseDiff(true)) for the sampling and defining the model argument types precisely (the latter, as suggested by @garrett, was not enough on its own to get the sampling started):

import ReverseDiff

@model function simplemodel(sst::Vector{Float64}, mean_::Vector{Float64}, L::LowerTriangular{Float64, Matrix{Float64}}, idx::Vector{Int64})
    n = length(mean_)
    random_scales ~ filldist(Normal(0, 1), n)
    full_field = mean_ .+ L * random_scales
    sst_field = full_field[idx]
    sst ~ MvNormal(sst_field, sst_err)
end

chain = sample(simplemodel(sst, mean_, L, idx), NUTS(adtype=AutoReverseDiff(true)), 1000)  # `true` enables the compiled ReverseDiff tape

finally kick-starts the sampling. The toy example takes 4 min. I’ll try it on my real problem, but if that holds, it would be faster than Python.

EDIT: that works: 3 min 17 s for one chain. Perfect. NOTE: the comparison with PyMC is beside the point (for that I’d need a proper linkage to the BLAS library, or other, faster backends such as JAX).
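As an aside, something like the following can be used to check and adjust how many threads the BLAS backend uses from Julia (a sketch using the standard LinearAlgebra interface; it only affects the dense linear algebra such as the L * random_scales product, not the sampler itself):

using LinearAlgebra

BLAS.get_num_threads()   # how many threads the loaded BLAS library currently uses
BLAS.set_num_threads(8)  # e.g. allow 8 threads for the dense matrix-vector products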

EDIT2: surprisingly, sampling the prior is now causing me trouble:
prior_chain = sample(simplemodel(sst, ...), Prior(), 1000, progress=true) takes a fraction of a second to complete the sampling, but a full 6 minutes to return a result.

EDIT3: for the prior sampling issue, I ended up adding return full_field to my model, and I now do priormodel = simplemodel(missing, ...); prior_sst_samples = hcat([priormodel() for _ in 1:1000]...) instead of sampling via sample(...). That takes 3 seconds instead of 6 minutes with sample.
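Spelled out, the workaround looks roughly like this (a sketch: the type annotations are dropped so that missing can be passed for sst, and sst_err is still the global defined above):

@model function simplemodel(sst, mean_, L, idx)
    n = length(mean_)
    random_scales ~ filldist(Normal(0, 1), n)
    full_field = mean_ .+ L * random_scales
    sst_field = full_field[idx]
    sst ~ MvNormal(sst_field, sst_err)
    return full_field  # returned so that calling the model yields a prior draw of the field
end

# draw 1000 prior samples by calling the model directly instead of sample(..., Prior(), ...)
priormodel = simplemodel(missing, mean_, L, idx)
prior_sst_samples = hcat([priormodel() for _ in 1:1000]...)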

Kind of strange with the prior sampling, but I’m glad to hear the posterior sampling is working better!

I wrote an issue with a full reproducible example…