NUTS speed is very slow for high dimension parameter inference in Turing.jl

Do you care about the posterior on M or do you just need to marginalize it and only care about sig? If the latter, this is a perfect example problem for https://cosmicmar.com/MuseInference.jl. You can beat the NUTS runtime by 100X or more.

The MUSE algorithm gives a Gaussian approximation to the marginal sig posterior. That approximation is very accurate for high-dimensional latent spaces (thanks to the central limit theorem) and is exact up to MC error if the likelihood is Gaussian (your case happens to be both).

Here’s some example code where I reduced the latent dimensionality to 1000 to make it easier to run. I’m getting about 100X faster than NUTS, and the relative improvement will grow further as you increase the dimensionality to your original 10000.

using MuseInference, Turing, Zygote, PyPlot
Turing.setadbackend(:zygote)

x = rand(1000,1000)
trueE = rand(1000)
testy = x * trueE
# you'll have to define the model without the observations
# in the arguments and instead use newer-style conditioning via `|`
@model function toy()
    sig ~ InverseGamma(1,1)
    M ~ MvNormal(1000,sig)
    y ~ MvNormal(x*M,sig)
end
model = toy() | (y=testy,)

# NUTS (~1000sec)
@time chain = sample(
    model, NUTS(100, 0.65), 100, progress=true
)

# MUSE (~10sec)
result = muse(
    model, (sig=0.5,), get_covariance=true,
    nsims=30, θ_rtol=1e-1, ∇z_logLike_atol=1,
)

# comparison plot
hist(chain[:sig], density=true)
sigs = range(xlim()...,length=1000)
plot(sigs, pdf.(result.dist, sigs))

[plot: histogram of the NUTS sig samples with the MUSE Gaussian approximation overlaid]
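
Since both the prior on M and the likelihood are Gaussian here, M can also be integrated out by hand, which gives an independent reference curve for the comparison. The following is only a sketch: it assumes `MvNormal(μ, sig)` treats `sig` as a standard deviation (isotropic covariance sig^2 * I), uses a smaller problem size so it runs quickly, and the names `logpost`, `sig_mode` and `dens` are mine, not part of Turing or MuseInference.

```julia
using LinearAlgebra, Random

Random.seed!(1)
n = 200                      # smaller than the 1000 used above, just to keep this quick
x = rand(n, n)
y = x * rand(n)

# Assumption: M ~ N(0, sig^2 I) and y | M ~ N(x*M, sig^2 I),
# so marginally y ~ N(0, sig^2 * (x*x' + I)).
A = Symmetric(x * x' + I)
q = dot(y, A \ y)            # quadratic form y' * inv(A) * y

# Unnormalized log marginal posterior of sig, including the InverseGamma(1,1)
# prior, whose log density is -2*log(sig) - 1/sig up to a constant.
logpost(sig) = -(n + 2) * log(sig) - q / (2 * sig^2) - 1 / sig

# Setting the derivative to zero gives (n+2)*sig^2 - sig - q = 0.
sig_mode = (1 + sqrt(1 + 4 * (n + 2) * q)) / (2 * (n + 2))

# Normalized density on a grid around the mode.
sigs = range(0.5 * sig_mode, 2 * sig_mode, length=2001)
lp = logpost.(sigs)
dens = exp.(lp .- maximum(lp))
dens ./= sum(dens) * step(sigs)
```

With the same `x` and `testy` as in the model above, `dens` over `sigs` could be drawn on top of the histogram and the MUSE curve as a third line.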

MUSE takes a starting guess for the sig value, which you can refine as you do more runs. The main parameters are the number of sims and the tolerances. The error on the estimated mean, relative to the uncertainty, goes like 1/sqrt(nsims), so a like-for-like comparison sets nsims to the ESS of the chain, which is what I did above. θ_rtol is a solver tolerance on sig relative to its uncertainty, and ∇z_logLike_atol is the absolute tolerance on the gradient w.r.t. M, which is used in an internal maximization over M.
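
To get a feel for the 1/sqrt(nsims) scaling, here is a small stand-alone sketch. It is not MUSE itself, just a plain Monte Carlo mean of standard normal draws, and `mean_error` is a helper name I made up for the illustration.

```julia
using Random, Statistics

Random.seed!(0)

# Spread of a sample mean over many repeated experiments, in units of the
# per-draw standard deviation (which is 1 for randn).
mean_error(nsims; reps=20_000) = std([mean(randn(nsims)) for _ in 1:reps])

for nsims in (30, 120, 480)
    println(nsims, "  measured: ", round(mean_error(nsims), digits=3),
            "  1/sqrt(nsims): ", round(1 / sqrt(nsims), digits=3))
end
```

Each 4X increase in nsims should roughly halve the measured error, which is why going much beyond the ESS of the chain you are comparing against buys relatively little.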

Let me know if you try it out and run into any issues!
