I have a spatial model on a (currently) 5x5 grid for two variables (total state size 3737), and 483 observations of one of the two variables. In this simple setting I want to work in the full space (no dimension reduction) and model the system as a multivariate normal distribution; I then want to validate the output against an ensemble Kalman filter before making the model more complex. Here is a minimal working example of my Turing.jl model:
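```julia
using Turing
using Distributions
using LinearAlgebra
using Statistics  # mean, cov

n = 3737    # state size (both variables on the grid)
N = 50      # number of ensemble members
nobs = 483  # number of observations

# Synthetic stand-in for a real ensemble (n x N)
ensemble = randn(n, N)
mean_ = vec(mean(ensemble, dims=2))
cov_ = cov(ensemble, dims=2)
# Jitter: the sample covariance has rank at most N - 1 < n
L = cholesky(Symmetric(cov_ + 1e-6I)).L

sst = randn(nobs)                         # observed values
sst_err = abs.(randn(nobs) * 0.1) .+ 0.1  # observation error standard deviations
idx = collect(1:nobs)                     # indices of the observed state entries

@model function simplemodel(sst)
    random_scales ~ filldist(Normal(0, 1), n)
    full_field = mean_ .+ L * random_scales
    sst_field = full_field[idx]
    # MvNormal expects a covariance matrix; sst_err holds standard deviations
    sst ~ MvNormal(sst_field, Diagonal(sst_err .^ 2))
end

chain = sample(simplemodel(sst), NUTS(), 1000)
```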
Unfortunately, sampling takes forever and does not even seem to start. By contrast, an equivalent PyMC implementation takes about 9 min to complete (with real data), and about 20 min when called from Julia via PyCall. Here is the PyMC model called via PyCall:
```julia
using PyCall

function run_pymc_model(mean, chol, idx, sst_err, observed_sst)
    py"""
import pymc as pm
import numpy as np

def build_and_run_model(mean, chol, idx, sst_err, observed_sst):
    # Define the PyMC model
    with pm.Model() as model:
        random_scales = pm.Normal('random_scales', mu=0, sigma=1, shape=mean.shape[0])
        full_field = mean + chol @ random_scales
        sst_observable = full_field[idx]
        pm.Normal('sst', mu=sst_observable, sigma=sst_err, observed=observed_sst)
        trace = pm.sample()
    return trace
"""
    # Fetch the Python function defined above
    build_and_run_model = py"build_and_run_model"
    trace = build_and_run_model(
        mean,
        chol,
        idx .- 1,  # adjust for Python's 0-based indexing
        sst_err,
        observed_sst)
    return trace
end

trace = run_pymc_model(mean_, L, idx, sst_err, sst)
```
Calling Python from Julia for performance seems to defeat the purpose of using Julia in the first place. Hence my question: is there a better way to implement my Julia model so that it matches PyMC's performance?
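Because every part of this model is linear and Gaussian, its posterior is also available in closed form, which could serve as a reference for validating both samplers and the EnKF. The sketch below uses only NumPy and assumes, as in both models above, that the observation operator `H` simply selects the (0-based) entries `idx` and that the observation errors are independent with standard deviations `obs_err`. It applies the standard Kalman update with prior covariance `P = chol @ chol.T`: posterior mean `mean + K (obs - H mean)` and covariance `P - K H P`, with gain `K = P Hᵀ (H P Hᵀ + R)⁻¹`. The function name `kalman_posterior` is hypothetical.

```python
import numpy as np


def kalman_posterior(mean, chol, idx, obs_err, obs):
    """Exact posterior of full_field = mean + chol @ z, with z ~ N(0, I),
    given obs ~ N(full_field[idx], diag(obs_err**2)).
    idx is assumed to be 0-based (NumPy indexing)."""
    P = chol @ chol.T                       # prior covariance of full_field
    PHt = P[:, idx]                         # P H^T, since H selects rows idx
    S = PHt[idx, :] + np.diag(obs_err**2)   # innovation covariance H P H^T + R
    K = np.linalg.solve(S, PHt.T).T         # Kalman gain P H^T S^{-1}
    post_mean = mean + K @ (obs - mean[idx])
    post_cov = P - K @ PHt.T                # P - K H P
    return post_mean, post_cov
```

With the arrays from the question (and `idx` shifted to 0-based), the returned mean and the square roots of the diagonal of the covariance could be compared with the posterior mean and standard deviation of `full_field` from either sampler. The dense covariance at n = 3737 takes about 110 MB in float64, but this approach would need rethinking for much larger grids.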