Excessive memory consumption when optimising Gaussian process regression with reverse-mode automatic differentiation

I am working on a model involving Gaussian processes, which I optimise with Optim.jl while obtaining gradients via DifferentiationInterface.jl. Unfortunately, when I run the code, the top utility shows that it uses a lot of memory.

While investigating, I noticed that even a simple Gaussian process regression model uses a lot of memory when optimised with Optim.jl and DifferentiationInterface.jl. I post some code below as an MWE that can be copy-pasted and executed (provided the packages are available):

using DifferentiationInterface
using Distributions
using LinearAlgebra
import Mooncake
using Optim
using Random

function gp(x, y; iterations = 1)

    # Get number of data items
    N = length(y)

    # Allocate once zero vector necessary for marginal likelihood 
    zeromean = zeros(N)

    # Initialise parameters randomly
    rng = MersenneTwister(1234)

    initialsol = randn(rng, 3)
    
    # pre-allocate N×N covariance matrix K
    K = zeros(N, N)

    # setup optimiser options
    opt = Optim.Options(iterations = iterations, show_trace = true, show_every = 1)

    helper(p) = negativemarginallikelihood_gp(p, K, x, y, zeromean)

    # use DifferentiationInterface to get gradients

    backend = AutoMooncake(config = nothing)
    
    prep = prepare_gradient(negativemarginallikelihood_gp, backend, initialsol, Cache(K), Constant(x), Constant(y), Constant(zeromean))   

    gradhelper!(grad, p) = DifferentiationInterface.gradient!(negativemarginallikelihood_gp, grad, prep, backend, p, Cache(K), Constant(x), Constant(y), Constant(zeromean))

    optimize(helper, gradhelper!, initialsol, ConjugateGradient(), opt).minimizer

end

# Negative marginal likelihood function of gp
function negativemarginallikelihood_gp(p, K, x, y, zeromean)

    N = length(y)

    θ = exp.(p) # make parameters positive

    # Calculate covariance matrix
    # Better implementations are of course possible
    for m in 1:N
        for n in 1:N
           K[n, m] =  θ[1] * exp(-0.5 * abs2(x[n] - x[m])/ θ[2])
        end
    end

    # add jitter on diagonal
    for n in 1:N
        K[n, n] += 1e-6
    end

    # Return negative log marginal likelihood.
    # We want to minimise this.
    return -logpdf(MvNormal(zeromean, K + θ[3]*I), y)

end

We can run the code with fake data:

x = randn(MersenneTwister(1234), 1000);
y = sin.(x) + 0.01*randn(MersenneTwister(1234), 1000);
gp(x,y; iterations = 1) # warmup
gp(x,y; iterations = 100) # top utility tells me that Julia uses ~8-9% of my 32GB memory.

If I try to run the code with 10_000 data items like below:

x = randn(MersenneTwister(1234), 10_000);
y = sin.(x) + 0.01*randn(MersenneTwister(1234), 10_000);
gp(x,y; iterations = 100)

Julia will run out of memory and crash.
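
In case it is useful, here is a rough way to check how much memory a single gradient call allocates. It assumes the MWE above has already been run (so that negativemarginallikelihood_gp and the packages are loaded), follows exactly the same DifferentiationInterface pattern as the MWE, and uses a small N only so that it runs quickly:

N = 200
x_small = randn(MersenneTwister(1), N);
y_small = sin.(x_small) + 0.01*randn(MersenneTwister(1), N);
K_small = zeros(N, N)
zeromean_small = zeros(N)
p0 = randn(MersenneTwister(2), 3)
grad = similar(p0)

backend = AutoMooncake(config = nothing)
prep = prepare_gradient(negativemarginallikelihood_gp, backend, p0, Cache(K_small), Constant(x_small), Constant(y_small), Constant(zeromean_small))

g!() = DifferentiationInterface.gradient!(negativemarginallikelihood_gp, grad, prep, backend, p0, Cache(K_small), Constant(x_small), Constant(y_small), Constant(zeromean_small))

g!()             # warm up / compile
@allocated g!()  # bytes allocated by one gradient evaluation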

My questions:

  • Is there something wrong with the above code?
  • Alternatively: does anyone have any successful examples of optimising a Gaussian process regression model using automatic reverse differentiation?

Update: I should clarify that I am not tied to reverse differentiation and will use anything that does the job. However, the objective I am optimising (in my actual problem, not the MWE above) is a scalar function f: Rᴹ → R, where M is the number of free parameters, typically on the order of 5000 to 15000 depending on the dataset.
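
For what it is worth, since DifferentiationInterface abstracts over backends, trying a different AD should only be a small change, something like the following (ForwardDiff is not in the MWE above, so the import is an addition, and I have not actually tried this):

import ForwardDiff

backend = AutoForwardDiff()   # instead of AutoMooncake(config = nothing)
# untested sketch: whether a given backend supports the Cache/Constant contexts
# used in prepare_gradient/gradient! is something to check per backend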

Hi @Nikos_Gianniotis

I am a GP researcher and not a Mooncake-er or Optim-er, so there are some details I won’t address here. If reverse mode is a hard requirement, as a general thought I would guess you want to be really careful that you are hitting some kind of rule for the Cholesky factorization and not just naively reverse-differentiating through it. Maybe that is already happening, but it is a general suggestion of something to check.

For fun, though, I tried modifying your code to use a Vecchia approximation with Vecchia.jl, a package I develop. For a well-behaved covariance function (in the sense of having a nice screening effect) and a one-dimensional process, these approximations can be made effectively exact. They correspond to approximating \mathbf{\Sigma}^{-1} with a sparse matrix, which in the 1D case will, in simple configurations, be banded around the diagonal. A good introductory reference is here (although I would recommend skimming it as an intro and not implementing SGV, which I can elaborate on if you want). But think of it like a weakened Markov-like assumption.
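
To be a little more concrete about what the approximation is (this is the standard way of writing it, nothing specific to Vecchia.jl): you replace the joint density with an ordered product of small conditionals, where each conditioning set c(i) holds at most m earlier points,

p(y_1, \dots, y_n) \approx p(y_1) \prod_{i=2}^{n} p\big(y_i \mid y_{c(i)}\big), \qquad c(i) \subseteq \{1, \dots, i-1\}, \quad |c(i)| \le m.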

Anyways, the upshot is that you write your big likelihood approximation as a sum of many small likelihood terms, and so if you are a little careful in the implementation you use O(1) memory even as the data size n grows. So here is something that runs for 10k points and will also run for 10M points:

using DifferentiationInterface
using Distributions
using LinearAlgebra
import Mooncake
import ReverseDiff
using Optim
using Random

using Vecchia, BesselK, StaticArrays

# Note 1: this is not the squared exponential kernel; I will explain below.
matern_12(x, y, p) = matern(x, y, (p[1], p[2], 1/2))

# Note 2: this data setting is a bit simpler than the usual one in Vecchia.jl, so 
# to save on compute I just wrote a simple little constructor for the config.
function gen_cfg(yv, xv, m)
  c = [collect(max(1, j-m):(j-1)) for j in 1:length(xv)]
  Vecchia.VecchiaConfig(matern_12, hcat.(yv), [[SA[x]] for x in xv], c)
end

function gp(x, y; iterations = 1)
    # Get number of data items
    N = length(y)
    # Initialise parameters randomly
    rng = MersenneTwister(1234)
    initialsol = [1.0, 0.01] #randn(rng, 3)
    # instead: create Vecchia configuration.
    sp  = sortperm(x)
    cfg = gen_cfg(y[sp], x[sp], 5)
    # negative log-likelihood function:
    data_nll = p -> nll(cfg, p)
    # setup optimiser options
    opt = Optim.Options(iterations = iterations, show_trace = true, show_every = 1)
    backend = AutoReverseDiff()
    prep = prepare_gradient(data_nll, backend, initialsol)   
    grad_data_nll!(grad, p) = DifferentiationInterface.gradient!(data_nll, grad, prep, backend, p)
    optimize(data_nll, grad_data_nll!, initialsol, ConjugateGradient(), opt).minimizer
end

x = randn(MersenneTwister(1234), 10000);
y = sin.(x) + 0.01*randn(MersenneTwister(1234), 10000);
gp(x,y; iterations = 1)
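
To make the "sum of many small terms" point concrete, here is a rough hand-rolled sketch of the kind of computation such an approximation does. To be clear, this is not what Vecchia.jl does internally, and the kernel, names, and jitter are just placeholders for illustration, but it shows why each term only ever touches an (at most) m×m matrix:

using Distributions, LinearAlgebra

# Toy exponential (Matern ν = 1/2) kernel; p = (variance, range)
toykernel(xi, xj, p) = p[1] * exp(-abs(xi - xj) / p[2])

# Hand-rolled Vecchia-style negative log-likelihood: each observation is
# conditioned only on its (at most) m nearest predecessors (x assumed sorted),
# so every term needs only a small dense factorization.
function toy_vecchia_nll(p, x, y, m)
    jitter = 1e-6   # same idea as the jitter in the original MWE
    out = -logpdf(Normal(0.0, sqrt(toykernel(x[1], x[1], p) + jitter)), y[1])
    for i in 2:length(y)
        c   = max(1, i - m):(i - 1)                                  # conditioning set
        Kcc = [toykernel(x[a], x[b], p) for a in c, b in c] + jitter * I
        kic = [toykernel(x[i], x[a], p) for a in c]
        w   = cholesky(Symmetric(Kcc)) \ kic
        μ   = dot(w, y[c])                                           # conditional mean
        σ²  = toykernel(x[i], x[i], p) + jitter - dot(w, kic)        # conditional variance
        out -= logpdf(Normal(μ, sqrt(σ²)), y[i])
    end
    return out
end

sp = sortperm(x)
toy_vecchia_nll([1.0, 1.0], x[sp], y[sp], 5)

The total work is O(n · m³), and the extra memory per term is O(m²) independent of n, which is the property Vecchia.jl exploits (presumably in a much more careful way than this loop).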

A few other notes though:
– Unless you think your process is actually analytic, a squared exponential kernel is probably not an advisable choice. Something that is just as cheap (to within a few ns) is a half-integer order Matern kernel. The Matern with ν = 5/2 gives you a process with two mean-square derivatives (see the short sketch after these notes).
– Your random initialization is probably also not doing you any favors. For many points on [0,1], it is pretty easy to pick a range parameter for which that matrix is numerically indefinite. I see you add a hefty nugget to the matrix for numerical reasons, but that is always a painful choice because you are throwing away a lot of local information by doing so. I can elaborate on this if you want, but you didn’t ask so this is already an obnoxiously long comment.
– You probably need to tell Optim about some box constraints here. I didn’t see how to do that in ten seconds of googling, but I’m sure it is possible. And you will naturally get errors if it picks a negative range parameter or something.
– I was getting an error with Mooncake saying I was ccall-ing something, which I guess I am if you count LAPACK calls like cholesky!, but otherwise am not. I switched to AutoReverseDiff and everything is fine. But there may be some tiny compat stuff to do with Mooncake.
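
To expand on the kernel point in the first note, here is the standard ν = 5/2 Matern written out, just to show it is as cheap as the squared exponential; the function name and argument convention here are mine, not from Vecchia.jl or BesselK.jl:

# Standard Matern ν = 5/2 kernel: σ² is the marginal variance, ℓ the range.
function matern52(xi, xj, σ², ℓ)
    r = sqrt(5) * abs(xi - xj) / ℓ
    return σ² * (1 + r + r^2 / 3) * exp(-r)
end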

Thanks for taking the time to reply.

That’s a good point, but I am unsure how to deal with it. How would I go about detecting or resolving something like this?

By the way, I should clarify that reverse mode is not a hard constraint, but in my actual problem I optimise a scalar objective function with respect to a few thousand free parameters so I thought reverse mode would be more appropriate.

Thanks for linking to your package and for the reference. I am a fan of sparse GPs via inducing points, but I have always wanted to read up on the Vecchia approximation. I didn’t know that a Julia package existed.

Good points. However, better initialisation or a better kernel will not help me reduce the memory consumption. Also, this is an MWE so I avoided making the code long.

Interesting. When I execute the code in the original post, I don’t get any such complaints, despite using MvNormal, which must use cholesky internally.

Thanks again for your reply.