NUTS is very slow for high-dimensional parameter inference in Turing.jl

Hi, I’m trying Turing.jl for a high-dimensional hyperparameter inference problem, and here’s a toy example. To estimate trueE from the fake testx and testy, I’m using NUTS, which takes at least 5 hours according to the progress meter. I’ve already switched the AD backend to ReverseDiff and used a multivariate normal distribution. Did I make any mistakes in building the model? Many thanks.

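```julia
using Turing, ReverseDiff, LinearAlgebra
Turing.setadbackend(:reversediff)   # the AD backend used in this post

testx = rand(1000, 10000)
trueE = rand(10000)
testy = testx * trueE

@model function toy(x, y)
    sig ~ InverseGamma(1, 1)
    # isotropic normals with standard deviation sig (covariance sig^2 * I)
    M ~ MvNormal(zeros(size(x, 2)), sig^2 * I)
    y ~ MvNormal(x * M, sig^2 * I)
end

model = toy(testx, testy)
@time chain = sample(model, NUTS(0.65), 1000);
```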

No obvious mistakes. In general, though, you should not expect large models to sample quickly: 5 hours to sample 10,000 parameters does not sound unrealistic. For models like this one, where the compute time is dominated by computing the gradient of the matrix-vector multiplication, I expect Zygote to perform better than ReverseDiff.
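To see why the matrix-vector product dominates, note that the gradient of the Gaussian log-likelihood with respect to M is x'(y - xM)/sig², i.e. one product with x and one with its transpose per gradient evaluation. Below is a minimal stdlib-only sketch (hypothetical helper names, small sizes) that checks this formula against a finite difference; it only shows the math any AD backend has to reproduce, not how Zygote or ReverseDiff execute it internally.

```julia
using Random

# Log-likelihood of y ~ N(x*M, sig^2 I), dropping terms that do not depend on M.
loglik(M, x, y, sig) = -0.5 * sum(abs2, y .- x * M) / sig^2

# Analytic gradient w.r.t. M: x' * (y - x*M) / sig^2.
# Cost: one mat-vec with x and one with x', each O(n*p).
grad_loglik(M, x, y, sig) = x' * (y .- x * M) ./ sig^2

rng = MersenneTwister(0)
n, p = 20, 50                 # small sizes so the check runs instantly
x = rand(rng, n, p)
y = x * rand(rng, p)
M = randn(rng, p)
sig = 0.5

# Central finite difference along the first coordinate as a sanity check.
h = 1e-6
e1 = zeros(p); e1[1] = 1.0
fd = (loglik(M .+ h .* e1, x, y, sig) - loglik(M .- h .* e1, x, y, sig)) / (2h)
println("analytic: ", grad_loglik(M, x, y, sig)[1], "  finite difference: ", fd)
```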


What about Turing.setrdcache(true) in this case?

This worked well for me in a different situation:

Three things.

AD Backend
First, I think the default backend for Turing is still ForwardDiff, which is very inefficient for a high-dimensional model like yours. I just tried your model on my laptop with ReverseDiff as the AD backend, and the progress meter shows 2:30:16.
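Why forward mode struggles here: it propagates derivatives along one input direction per pass, so the full gradient of a scalar log density with p parameters needs on the order of p passes. (ForwardDiff.jl batches several directions per pass via chunking, but the number of passes still grows with p.) Reverse mode instead gets the whole gradient from one forward and one backward sweep. The hand-rolled `Dual` type below is a minimal illustrative sketch of forward mode, not ForwardDiff.jl's implementation.

```julia
# Minimal forward-mode AD via dual numbers (illustrative; not ForwardDiff.jl).
struct Dual
    val::Float64  # primal value
    der::Float64  # derivative along the single seeded input direction
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)

# Example objective: f(θ) = Σ θᵢ², written only with + and * on Duals.
f(θ) = sum(t * t for t in θ)

# Full gradient by forward mode: one evaluation of f per input dimension.
function forward_gradient(f, θ::Vector{Float64})
    n = length(θ)
    g = zeros(n)
    passes = 0
    for i in 1:n
        seeded = [Dual(θ[j], j == i ? 1.0 : 0.0) for j in 1:n]  # seed direction eᵢ
        g[i] = f(seeded).der
        passes += 1
    end
    return g, passes
end

g, passes = forward_gradient(f, [1.0, 2.0, 3.0])
println("gradient = ", g, " after ", passes, " forward passes")
```

With 10,000 parameters the same loop would need 10,000 passes, each costing a full model evaluation.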

Hierarchical Prior
Second, InverseGamma(1,1) is too weak: its heavy right tail (with shape 1 it does not even have a finite mean) is hard for the sampler to navigate. This leads NUTS to choose long integration trajectories, and the cost of each iteration is roughly proportional to the trajectory length. People mainly use the inverse gamma because it is a conjugate prior for the variance of a normal, which is irrelevant to MCMC, so we are free to use alternative priors. A better choice is a more informative prior with lighter tails, like truncated(Normal(0, 10), 0, Inf). This alone reduced the projected sampling time to 1:20:45. See the Stan prior choice recommendations wiki for an up-to-date list by the Stan developers.
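A quick way to see the tail difference, as a stdlib-only sketch: for InverseGamma(1, 1) the CDF is exp(-1/t), so P(sig > t) = 1 - exp(-1/t), which decays only like 1/t. The code checks this by Monte Carlo (using the fact that 1/E is InverseGamma(1, 1) when E ~ Exponential(1)) and compares it with draws from a half-normal with scale 10, which is what truncated(Normal(0, 10), 0, Inf) is. The helper name `ig_tail` is hypothetical.

```julia
using Random, Statistics

# Tail probability P(sig > t) for sig ~ InverseGamma(1, 1): 1 - exp(-1/t).
ig_tail(t) = -expm1(-1 / t)

rng = MersenneTwister(1)
n = 10^6
ig_draws = 1 ./ randexp(rng, n)          # 1/Exponential(1) ~ InverseGamma(1, 1)
hn_draws = abs.(10 .* randn(rng, n))     # half-normal, scale 10

for t in (10.0, 100.0, 1000.0)
    println("t = $t: InverseGamma(1,1) tail = $(round(ig_tail(t), sigdigits=3)) ",
            "(MC ≈ $(round(mean(ig_draws .> t), sigdigits=3))), ",
            "half-normal(10) MC tail = $(mean(hn_draws .> t))")
end
```

Roughly 1% of the InverseGamma(1, 1) mass sits above sig = 100, while the half-normal puts essentially none there.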

Forcing Short Trajectories
The last measure is to reduce the maximum tree depth, as in NUTS(0.65, max_depth=8). Combined with the informative prior above, the progress meter shows 0:40:18. The default max_depth is 10, which allows at most 2^10 = 1024 leapfrog steps per iteration, while 8 caps it at 2^8 = 256. This measure, however, hurts the statistical efficiency of the sampler, so it is better to fix the model so that the sampler does not hit the maximum depth. (That said, fundamentally hard inference problems do exist, such as sparse regression and stochastic volatility models. These remain an open challenge for modern inference algorithms, so as end users there is not much we can do about them.)
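A back-of-envelope sketch of what max_depth bounds, under a stated assumption: each gradient evaluation costs about two mat-vec products with the 1000×10000 design matrix (x*M forward, x'*r backward), ignoring AD overhead. This only bounds the worst case; actual runtimes depend on how many iterations actually reach the cap. All names are hypothetical.

```julia
# Rough worst-case cost model for the toy NUTS run (a sketch, not a benchmark).
n, p = 1000, 10_000

# Assumption: one gradient ≈ two mat-vec products, each about 2*n*p flops.
flops_per_grad = 2 * (2 * n * p)

# NUTS doubles its trajectory until a U-turn or until max_depth is reached,
# so one iteration uses at most 2^max_depth leapfrog steps (one gradient each).
max_leapfrog(max_depth) = 2^max_depth

worst_case_flops(max_depth, iters) = max_leapfrog(max_depth) * flops_per_grad * iters

for depth in (10, 8)
    println("max_depth = $depth: ≤ $(max_leapfrog(depth)) leapfrog steps per iteration, ",
            "≤ $(worst_case_flops(depth, 1000)) flops for 1000 iterations")
end
```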

One of the difficulties with the current Bayesian workflow is that model design is not entirely independent of inference: it strongly affects the sampler’s performance, both statistically and computationally. So you should tweak the model so that NUTS can do its job quickly and efficiently. Mike Betancourt’s blog has a lot of good guidelines on these aspects. See, for example: Identity Crisis


Many thanks for your detailed reply. It reminds me of the pain points of today’s common Bayesian methods, and that I should spend more time improving the model itself. I guess there is not much I can currently do with Turing’s samplers.

Indeed, after switching to Zygote, sampling got slower. I’m also not clear on how Turing computes the log joint; it seems to compute the full conditional posterior for any given prior, which is not an easy job for high-dimensional data.
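For reference, HMC/NUTS does not need full conditionals: each leapfrog step only needs the gradient of the unnormalized log joint, log p(sig) + log p(M | sig) + log p(y | M, sig). Below is a hand-written sketch of that quantity for the toy model (hypothetical helper `log_joint`). Turing additionally maps sig to an unconstrained space (a log transform plus a Jacobian term), which this sketch omits.

```julia
# Unnormalized log joint of the toy model (hypothetical helper, written by hand):
#   sig ~ InverseGamma(1, 1),  M ~ N(0, sig^2 I),  y ~ N(x*M, sig^2 I)
function log_joint(sig, M, x, y)
    sig <= 0 && return -Inf
    # InverseGamma(α=1, β=1) log density: -(α+1)*log(sig) - β/sig  (log Γ(1) = 0)
    lp_sig = -2 * log(sig) - 1 / sig
    # Isotropic normal prior on M with standard deviation sig
    p = length(M)
    lp_M = -0.5 * p * log(2π) - p * log(sig) - 0.5 * sum(abs2, M) / sig^2
    # Gaussian likelihood of y given M and sig
    n = length(y)
    r = y .- x * M
    ll = -0.5 * n * log(2π) - n * log(sig) - 0.5 * sum(abs2, r) / sig^2
    return lp_sig + lp_M + ll
end

println(log_joint(0.5, zeros(3), ones(2, 3), ones(2)))
```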

It does help, and many thanks for the suggestion, but the bottleneck here is probably the model itself, as discussed in the other replies.

Do you care about the posterior on M, or do you just need to marginalize it out and only care about sig? If the latter, this is a perfect example problem for https://cosmicmar.com/MuseInference.jl. You can beat the NUTS runtime by 100X or more.

The MUSE algorithm gives a Gaussian approximation to the marginal sig posterior, but it is very accurate for high-dimensional latent spaces (due to the central limit theorem) and is exact up to MC error if the likelihood is Gaussian (your case happens to be both).
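To make “marginalize M” concrete: in this linear-Gaussian toy model M can be integrated out in closed form. With M ~ N(0, sig² I) and y | M ~ N(xM, sig² I), the marginal is y | sig ~ N(0, sig² (xx' + I)), so K = xx' + I can be factorized once and log p(y | sig) evaluated cheaply for any sig. The stdlib-only sketch below (small sizes, hypothetical helper names) is not MUSE; it is just an exact reference one could compare a marginal sig posterior against in this special case.

```julia
using LinearAlgebra, Random

# log p(y | sig) with M integrated out: y | sig ~ N(0, sig^2 * K), K = x*x' + I.
function marginal_loglik(sig, y, C::Cholesky)
    n = length(y)
    quad = dot(y, C \ y)                      # y' K^{-1} y
    return -0.5 * (n * log(2π) + 2n * log(sig) + logdet(C) + quad / sig^2)
end

rng = MersenneTwister(0)
n, p = 50, 200                                # small sizes so the sketch runs instantly
x = randn(rng, n, p)
sig_true = 0.7
M = sig_true .* randn(rng, p)
y = x * M .+ sig_true .* randn(rng, n)

C = cholesky(Symmetric(x * x' + I))           # K does not depend on sig: factorize once
sigs = range(0.1, 3.0; length = 500)
ll = [marginal_loglik(s, y, C) for s in sigs]
sig_hat_grid = sigs[argmax(ll)]
sig_hat_closed = sqrt(dot(y, C \ y) / n)      # maximizer from d/dsig = 0
println("grid argmax: $sig_hat_grid, closed form: $sig_hat_closed")
```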

Here’s example code where I reduced the latent dimensionality to 1000 to give something easier to run. I’m getting about a 100X speedup over NUTS, and the relative improvement will grow further as you increase the dimensionality to your original 10000.

```julia
using MuseInference, Turing, Zygote, PyPlot, LinearAlgebra
Turing.setadbackend(:zygote)

x = rand(1000,1000)
trueE = rand(1000)
testy = x * trueE
# you'll have to define the model without the observations
# in the arguments and instead use newer-style conditioning via `|`
@model function toy()
    sig ~ InverseGamma(1,1)
    # isotropic normals with standard deviation sig (covariance sig^2 * I)
    M ~ MvNormal(zeros(1000), sig^2 * I)
    y ~ MvNormal(x*M, sig^2 * I)
end
model = toy() | (y=testy,)

# NUTS (~1000sec)
@time chain = sample(
    model, NUTS(100, 0.65), 100, progress=true
)

# MUSE (~10sec)
result = muse(
    model, (sig=0.5,), get_covariance=true,
    nsims=30, θ_rtol=1e-1, ∇z_logLike_atol=1,
)

# comparison plot
hist(chain[:sig], density=true)
sigs = range(xlim()...,length=1000)
plot(sigs, pdf.(result.dist, sigs))
```

[Figure: comparison plot of the NUTS histogram of sig and the MUSE approximation]

MUSE takes a starting guess for the sig value, which you can refine as you do more runs. The main parameters are the number of simulations (nsims) and the tolerances. The error on the estimated mean relative to the uncertainty scales like 1/sqrt(nsims), so a like-for-like comparison sets nsims to the ESS of the chain, which is what I did above. θ_rtol is a solver tolerance on sig relative to its uncertainty, and ∇z_logLike_atol is the absolute tolerance on the gradient with respect to M used in the internal maximization over M that MUSE performs.
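As a generic illustration of the 1/sqrt(nsims) scaling (plain Monte Carlo averaging, not MUSE’s internals): with nsims = 30 the error on the mean is about 1/sqrt(30) ≈ 0.18 of the posterior standard deviation. The hypothetical helpers below compare that prediction with a simulation using only the standard library.

```julia
using Random, Statistics

# Predicted spread of an average of nsims unit-variance draws.
rel_error(nsims) = 1 / sqrt(nsims)

# Empirical spread: standard deviation of many independent averages.
function empirical_rel_error(rng, nsims; ntrials = 20_000)
    means = [mean(randn(rng, nsims)) for _ in 1:ntrials]
    return std(means)
end

rng = MersenneTwister(42)
for nsims in (30, 100)
    println("nsims = $nsims: predicted ≈ $(round(rel_error(nsims), digits=3)), ",
            "empirical ≈ $(round(empirical_rel_error(rng, nsims), digits=3))")
end
```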

If you try it out, let me know if you run into any issues!


Sorry for the late reply; it was finals week. I really appreciate this interesting approximation method, but I’m afraid the posterior of M is what I’m interested in, and it seems I couldn’t run your sample code in a Julia 1.6 environment.