GP-LVM with Turing?

I saw some of the recent posts these last few weeks about implementing Gaussian processes in Turing. I’ve really been wanting to play around with GP-LVMs for a while (in particular, as part of a larger model), so I took a stab at implementing the basics of one using those recently posted examples.

The issue is that the model is extremely slow. I’ve worked with GPs (including GPs with latent inputs) enough on my own to have realistic expectations, but even a very small model was much slower than I would have expected.

I don’t know that I expect a fix or anything, but I thought I would make this post in case anyone is curious about this use case, perhaps for future testing/benchmarking. The code below isn’t a fully proper GP-LVM (we’d probably want to use the ARD transform from KernelFunctions for that; a rough sketch of that variant is included after the code), but it does the basic operations involved, afaik.

With only 50 data points (three output dimensions), it took about 100 minutes to run using NUTS. Is NUTS just a bad fit for this kind of model?

using Turing, KernelFunctions, ReverseDiff, LinearAlgebra, BenchmarkTools
Turing.setadbackend(:reversediff)

# Scaled squared-exponential covariance: alpha^2 * SEKernel with inputs rescaled by sqrt(0.5)/rho
sekernel(alpha, rho) = alpha^2 * KernelFunctions.transform(SEKernel(), sqrt(0.5)/rho)

@model function GPLVM(y1, y2, y3, jitter=1e-6)
  N, = size(y1)
  P = 3
  
  # Priors
  X ~ filldist(Normal(0, 1), N)
  alpha ~ LogNormal(0, 0.1)
  rho ~ filldist(LogNormal(0, 0.1), P)
  sig2 ~ filldist(LogNormal(0, 1), P)
  
  # GP Covariance matrix
  kernel1 = sekernel(alpha, rho[1])  # covariance function
  kernel2 = sekernel(alpha, rho[2])  # covariance function
  kernel3 = sekernel(alpha, rho[3])  # covariance function

  K1 = kernelmatrix(kernel1, X)  # cov matrix
  K2 = kernelmatrix(kernel2, X)  # cov matrix
  K3 = kernelmatrix(kernel3, X)  # cov matrix

  K1 += I * (sig2[1] + jitter)
  K2 += I * (sig2[2] + jitter)
  K3 += I * (sig2[3] + jitter)

  # Sampling Distribution.
  y1 ~ MvNormal(zeros(N), K1)
  y2 ~ MvNormal(zeros(N), K2)
  y3 ~ MvNormal(zeros(N), K3)
end

y = randn(30, 3)
@time gp = GPLVM(y[:,1], y[:,2], y[:,3])
@time chain = sample(gp, NUTS(0.65), 2000)
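For reference, here’s roughly what I had in mind by a more proper ARD version: a Q-dimensional latent space with per-latent-dimension lengthscales for each output. This is just an untested sketch (the name ARDGPLVM, the choice of Q, and the priors are placeholders), using KernelFunctions’ with_lengthscale and ColVecs:

ardkernel(alpha, rho) = alpha^2 * with_lengthscale(SEKernel(), rho)  # rho is a vector of ARD lengthscales

@model function ARDGPLVM(Y, Q=2, jitter=1e-6)
  N, P = size(Y)

  # Priors
  X ~ filldist(Normal(0, 1), Q, N)          # latent inputs, one column per observation
  alpha ~ filldist(LogNormal(0, 0.1), P)
  rho ~ filldist(LogNormal(0, 0.1), Q, P)   # per-latent-dimension lengthscales for each output
  sig2 ~ filldist(LogNormal(0, 1), P)

  # One GP per output dimension
  for p in 1:P
    K = kernelmatrix(ardkernel(alpha[p], rho[:, p]), ColVecs(X))
    K += I * (sig2[p] + jitter)
    Y[:, p] ~ MvNormal(zeros(N), K)
  end
end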

While I would like to be able to use GP-LVMs as part of larger models in Turing, if anyone has suggestions for fitting them in general (e.g., alternative packages), I’d certainly be interested. I can do them by hand, of course, but I’m looking for something that makes them easier to use and lets me write less of my own code.

This could be ReverseDiff-related. Could you benchmark this with the various AD backends and report on the results?
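In case it’s useful, switching should just be a matter of calling one of these before sample (this is the older global-setting API used in the script above):

Turing.setadbackend(:forwarddiff)
Turing.setadbackend(:reversediff)
Turing.setadbackend(:zygote)
Turing.setadbackend(:tracker)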

Sure. Will post results later.

Before I posted this, I swear I tried with and without the setrdcache setting and didn’t see much difference, but enabling it this time, it does indeed make a large difference in speed.
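For reference, this is roughly what I’m running to enable it (I believe Memoization.jl has to be loaded for the cache to work):

using Turing, ReverseDiff, Memoization
Turing.setadbackend(:reversediff)
# Cache/reuse the compiled ReverseDiff tape between gradient calls.
# Only safe when the tape doesn't change between calls (no parameter-dependent branching).
Turing.setrdcache(true)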

With ReverseDiff, for 30 data points with 3 dimensions, running NUTS for 1000 iterations takes less than 2 minutes with rdcache enabled, and 7 minutes without. ForwardDiff takes about 1 minute (but I would expect it to fall behind quickly as sample size grows). Tracker gives an error.

Those times don’t sound bad at all, but fwiw, using the iris data set last night without memoization (~150 data points, 4 dimensions) took over 10 hours for 2000 samples. I know how badly GPs scale, but that was still a good bit worse than I was (perhaps naively?) expecting. I may test again with slightly larger sample sizes, just to see how things scale (out of curiosity).

Ah interesting. Could you try with Zygote as well please?

> Those times don’t sound bad at all, but fwiw, using the iris data set last night without memoization (~150 data points, 4 dimensions) took over 10 hours for 2000 samples.

I agree that this sounds excessive. To really know whether or not it is, it would be helpful if you could estimate the time per gradient evaluation; since NUTS is adaptive, it’s hard to tell what’s really going on from the wall-clock timings alone.
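Something along these lines should give a rough per-gradient number (just a sketch; it assumes a reasonably recent Turing/DynamicPPL that exposes the model via the LogDensityProblems interface, and it reuses the gp model object from your script):

using Turing, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
using ForwardDiff, ReverseDiff, Zygote, BenchmarkTools

ldf = DynamicPPL.LogDensityFunction(gp)   # joint log density of the model
theta0 = DynamicPPL.VarInfo(gp)[:]        # a valid parameter vector drawn from the prior

for backend in (:ForwardDiff, :ReverseDiff, :Zygote)
  grad = LogDensityProblemsAD.ADgradient(backend, ldf)
  println(backend)
  @btime LogDensityProblems.logdensity_and_gradient($grad, $theta0)
end
# For the cached/compiled ReverseDiff tape, ADgradient(:ReverseDiff, ldf; compile=Val(true))
# should be the rough equivalent.

That gives per-gradient-call numbers that are comparable across backends, independent of how many leapfrog steps NUTS ends up taking.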

OK, I re-ran some models with different AD backends. I upped the sample size to 60 (still with three dimensions) and ran NUTS for 1000 iterations.

ReverseDiff (no rdcache): 3422 seconds (one run)
ReverseDiff (with rdcache): 730-760 seconds (two runs)
ForwardDiff: 339-348 seconds (two runs)
Zygote: 138-147 seconds (two runs)

So Zygote is by far the fastest, and standard ReverseDiff (with no memoization) is… not.
