Large memory consumption when using Mooncake via DifferentiationInterface for Gaussian process optimisation

Hello everyone,

I am working on a model called the Gaussian process latent variable model (GPLVM). The GPLVM is a dimensionality reduction method which, given a high-dimensional dataset Y, returns low-dimensional projections X.

In order to optimise the free parameters, I use Optim.jl in conjunction with DifferentiationInterface.jl and Mooncake.jl, which calculate the gradients for me automatically. While I am happy with this setup, I have noticed very high memory consumption: the top utility informs me that more than 20% of my 32GB of memory is in use, and this for just a modest number of data items. While in principle I don't mind this, if I increase the number of data items, more memory is consumed and the system kills my Julia session.

My actual code is quite lengthy. For the purposes of this question, I have cut it down significantly, but it still exceeds the length of a typical MWE. The code is organised in three functions:

  • function gplvm is called by the user. It sets up the optimiser options, sets up the gradient, calls the optimiser and finally returns the low-dimensional coordinates X inferred by the GPLVM model.
  • function negativemarginallikelihood is the objective function to be minimised with respect to the coordinates X, the parameters θ of the Gaussian process kernel and the noise variance σ². If you are not familiar with the GPLVM but know about Gaussian processes, then this strongly resembles the negative marginal log-likelihood of a regular Gaussian process regression model (see the formula after this list).
  • function unpack_gplvm is an auxiliary function that converts the vector of free, unconstrained parameters p into the parameters X, θ and σ².
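
Concretely, the objective implemented below is (writing y_d for the d-th row of Y and x_n for the n-th column of X; a small jitter is also added to the diagonal of K):

$$
-\log p(Y \mid X, \theta, \sigma^2) = -\sum_{d=1}^{D} \log \mathcal{N}\!\left(y_d \mid 0,\; K + \sigma^2 I\right), \qquad K_{nm} = \theta_1 \exp\!\left(-\frac{\lVert x_n - x_m \rVert^2}{2\,\theta_2}\right)
$$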

 using DifferentiationInterface
 using Distances
 using Distributions
 using JLD2
 using LinearAlgebra
 import Mooncake
 using Optim
 using Random


"""
Y are the D×N high-dimensional data points
iterations is the number of iterations of the optimisation algorithm
Q is the dimensionality of the latent space
"""
function gplvm(Y; iterations = 1, Q = 2)

    # Get number of data items
    D, N = size(Y)

    # Allocate once zero vector necessary for marginal likelihood 
    # calculation of zero-mean Gaussian process
    zerovector = zeros(N)

    # Initialise parameters randomly:
    # first Q*N elements are the N latent Q-dimensional projections X
    # next 2 elements are kernel parameters - take log here because unpack function uses exp to ensure positivity
    # last parameter is the noise variance  - take log here because unpack function uses exp to ensure positivity
    
    rng = MersenneTwister(1234)

    initialsol = [randn(rng, Q*N)*0.1; log(1.0); log(1.0); log(1.0)]

    # pre-allocate N×N covariance matrix K
    K = zeros(N, N)

    # setup optimiser options
    opt = Optim.Options(iterations = iterations, show_trace = true, show_every = 1)

    # use DifferentiationInterface to get gradients

    # The lines below use Mooncake; to switch backends, comment them out and uncomment the Enzyme block further down
    backend = AutoMooncake(config = nothing)
    
    prep = prepare_gradient(negativemarginallikelihood, backend, initialsol, Constant(Y), Constant(zerovector), Cache(K), Constant(D), Constant(Q), Constant(N))   

    gradhelper!(grad, p) = DifferentiationInterface.gradient!(negativemarginallikelihood, grad, prep, backend, p, Constant(Y), Constant(zerovector), Cache(K), Constant(D), Constant(Q), Constant(N))

    helper(p) = negativemarginallikelihood(p, Y, zerovector, K, D, Q, N)


    # To use Enzyme instead, uncomment the lines below and comment out the Mooncake block above
    # backend = AutoEnzyme()
  
    # helper(p) = negativemarginallikelihood(p, Y, zerovector, K, D, Q, N)
    
    # prep = prepare_gradient(helper, backend, initialsol)
    
    # gradhelper!(grad, p) = DifferentiationInterface.gradient!(helper, grad, prep, backend, p)


    # call actual optimisation
    finalsolution = optimize(helper, gradhelper!, initialsol, ConjugateGradient(), opt).minimizer

    # obtain optimised latent
    X = unpack_gplvm(finalsolution, Q, N)[1]

    # return projections
    return X 

end




# Negative marginal likelihood function of GPLVM.
# We want to minimise this.
function negativemarginallikelihood(p, Y, zerovector, K, D, Q, N)

    # extract parameters from vector p
    X, θ, σ² = unpack_gplvm(p, Q, N)

    # calculate pairwise squared Euclidean distances.
    # Obviously, a more efficient implementation is possible (see the pairwise! sketch after this function).
    for n in 1:N
        for m in 1:N
           @views K[n, m] = sum((X[:, n] - X[:, m]).^2)
        end
    end

    # overwrite K entries with covariance matrix elements
    for n in eachindex(K)
        K[n] = θ[1] * exp(-0.5 * K[n] / θ[2])
    end

    # add jitter on diagonal
    for n in 1:N
        K[n, n] += 1e-6
    end

    # accumulate log likelihood over the D dimensions here
    accloglikel = zero(eltype(p))

    # instantiate multivariate normal distribution
    # (note that K + σ²*I allocates a fresh N×N matrix on every call)
    mvn = MvNormal(zerovector, K + σ²*I)

    # iterate over D dimensions
    for d in 1:D

        # calculate log likelihood of d-th dimension
        accloglikel += @views logpdf(mvn, Y[d, :])

    end

    # return negative log marginal likelihood
    return -accloglikel

end
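
As an aside on the "more efficient implementation is possible" comment above: since Distances.jl is already loaded, the hand-written distance loop could be replaced by an in-place pairwise! call. A minimal sketch (this only addresses allocations in the primal computation; whether Mooncake has a rule for pairwise! is a separate question):

# Fill K in place with pairwise squared Euclidean distances between
# the columns of X (each column is one latent point); equivalent to
# the double loop in negativemarginallikelihood.
pairwise!(K, SqEuclidean(), X; dims = 2)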


# Given parameters flattened in p, unpack them into X, θ and σ²
function unpack_gplvm(p, Q, N)

    MARK = 0

    # First Q*N elements are the N latent Q-dimensional projections X
    X = reshape(p[MARK+1:MARK+Q*N], Q, N); MARK += Q*N

    # Next two elements are kernel parameters
    θ = exp.(p[MARK+1:MARK+2]); MARK += 2

    # The last parameter is the noise variance
    σ² = exp(p[MARK+1]); MARK += 1

    return X, θ, σ²

end
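
A quick sanity check of unpack_gplvm, just to make the layout of p concrete (a hypothetical snippet, not part of the model):

Q, N = 2, 5
p = [randn(Q*N)*0.1; log(1.0); log(1.0); log(1.0)]

X, θ, σ² = unpack_gplvm(p, Q, N)

size(X)   # (2, 5): N latent Q-dimensional points
length(θ) # 2: the two kernel parameters
σ²        # 1.0, since exp(log(1.0)) == 1.0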

To execute the code and observe the high memory consumption, we can simply call it with randomly generated data:

Y = randn(12, 1000) # 1000 data items with 12 features
X = gplvm(Y; iterations = 1) # warmup
X = gplvm(Y; iterations = 100) # note memory consumption during execution
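
To put a number on it beyond watching top, one possibility is to compare live heap bytes before and after a run. A sketch (Base.gc_live_bytes gives only a rough, GC-timing-dependent reading):

GC.gc(true)
before = Base.gc_live_bytes()   # live heap bytes before the run

X = gplvm(Y; iterations = 100)

GC.gc(true)
after = Base.gc_live_bytes()    # live heap bytes retained afterwards
println("retained ≈ ", (after - before) / 1e9, " GB")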

If I run this:

Y = randn(12, 10_000) # 10000 data items with 12 features
X = gplvm(Y; iterations = 100)

the system will kill the julia session.

I strongly believe that the high memory consumption is related to the use of DifferentiationInterface.jl and Mooncake.jl. I thought that the use of contexts would help, but unfortunately it didn't. I have gone through the documentation, but I can't see whether I am missing something. Does anyone have any advice on how I could reduce my memory footprint? Thanks for tolerating this very long question.

(For what it's worth, I have made the above code available at GitHub - ngiann/GPLVM.jl.)


Perhaps useful: I am using Julia Version 1.11.4 on Ubuntu 22.04.5 LTS.


Update: I have updated the code so that Enzyme can be commented in and Mooncake commented out. This is in response to a comment below.


Update: I have introduced a new function that uses the gradient-free NelderMead optimiser:

"""
Same as gplvm but uses the gradient-free NelderMead optimiser.
"""
function gplvm_gradient_free(Y; iterations = 1, Q = 2)

    # Get number of data items
    D, N = size(Y)

    # Allocate once zero vector necessary for marginal likelihood 
    # calculation of zero-mean Gaussian process
    zerovector = zeros(N)

    # Initialise parameters randomly:
    # first Q*N elements are the N latent Q-dimensional projections X
    # next 2 elements are kernel parameters - take log here because unpack function uses exp to ensure positivity
    # last parameter is the noise variance  - take log here because unpack function uses exp to ensure positivity
    
    rng = MersenneTwister(1234)

    initialsol = [randn(rng, Q*N)*0.1; log(1.0); log(1.0); log(1.0)]

    # pre-allocate N×N covariance matrix K
    K = zeros(N, N)

    # setup optimiser options
    opt = Optim.Options(iterations = iterations, show_trace = true, show_every = 1)

    # objective function to be optimised
    helper(p) = negativemarginallikelihood(p, Y, zerovector, K, D, Q, N)

    # call actual optimisation
    finalsolution = optimize(helper, initialsol, NelderMead(), opt).minimizer

    # obtain optimised latent
    X = unpack_gplvm(finalsolution, Q, N)[1]

    # return projections
    return X 

end

Thanks for opening up this question. I’m going to have a dig around on my end and see if I can see what’s going on. Will get back to you!


I’m seeing very large memory usage on my machine also – will investigate. I agree that it seems at least a little bit odd to see this much memory use.

Thanks for confirming. If I do the optimisation using the gradient-free NelderMead() method, the memory usage stays low. This is why I suspect that something does not work as expected when I use the automatic gradients.

So I think what’s going on here is just that, because you’re using reverse-mode AD, you need quite a lot of memory.

My quick back-of-the-envelope calculation suggests that each copy of a matrix the same size as K, when you have 10_000 inputs, is almost 1GB. Mooncake (and reverse-mode AD in general, although precise numbers will vary from package to package) has to store at least the equivalent of several copies of that in order to operate. So I think there's a decent chance that what we're seeing is largely to be expected.
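
For concreteness, the size arithmetic for a single copy of K (a rough estimate that ignores all other AD bookkeeping):

N = 10_000
bytes_per_K = N^2 * sizeof(Float64)  # one dense N×N Float64 matrix
bytes_per_K / 1e9                    # ≈ 0.8, i.e. almost 1GB per copy of K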

Out of interest, have you tried out Enzyme on this function? I believe it might have some optimisations which mean that it currently uses a bit less memory than Mooncake in some cases.


Thanks for getting back to me. The trouble is that, according to top, I already use more than 1GB of memory for the 1000 data items (that's the 20% memory consumption I mention above). Running it for 10_000 data items with

Y = randn(12, 10_000) # 10000 data items with 12 features
X = gplvm(Y; iterations = 100)

makes my system run out of memory (I have 32GB in total) and crashes Julia. I understand that a large covariance matrix of dimensions 10_000×10_000 requires a lot of memory, but I already use a lot in the 1000 data items case, and Julia crashes in the 10_000 case.

I tried using Enzyme instead (I updated the code above accordingly), but this gives me the following error:

Function argument passed to autodiff cannot be proven readonly.

Thanks for your time. I will continue investigating.

I just made a curious observation concerning memory usage. However, I am not entirely sure if it does indeed relate to my problem.

As I mentioned above, if I start a new Julia session and perform optimisation using the gradient-free NelderMead optimiser, like this:

# replace line `finalsolution = optimize(helper, gradhelper!, initialsol, ConjugateGradient(), opt).minimizer` with
finalsolution = optimize(helper,  initialsol, NelderMead(), opt).minimizer

the memory consumption according to top will be about 3-4%.

However, if I start a new Julia session and first do the optimisation using the automatic gradients and only after that switch to the gradient-free optimiser NelderMead, I still see a high memory usage of about 14%. I would have expected the memory usage to drop back to the previously observed 3%.

I repeated the above multiple times (i.e. restarted Julia, rebooted the machine) in order to reassure myself that it is not a quirk of some external programme interfering on my Linux machine, but I observe this behaviour consistently.

Ah, yes, so this is Mooncake's caching mechanism doing its thing. There's a bunch of memory that it holds on to in a global variable so that it doesn't have to reallocate it each time you differentiate a function that it has seen before.

If you call

empty!(Mooncake.get_interpreter().oc_cache) # uses internals -- not part of the public interface (yet)
GC.gc(true)

you should see the memory consumption go back to the usual levels.

I should probably make clearing this cache part of the public interface…

Thanks for this very useful piece of information. So, if I understand correctly, this is tangential to the problem.

Unfortunately, when using the empty! / GC.gc snippet above in the REPL, I don't observe the memory usage returning to the lower percentage previously observed when using NelderMead.

Oh, interesting. Would you mind trying to advance the world age (e.g. by defining a new function, and checking that Base.get_world_counter() has changed) and then GC-ing? This should also do the trick but, yes, it ought to be orthogonal to the problem. If it doesn't resolve it, I'll run things locally again and see if I can replicate.

For reproducibility’s sake, I introduced in the top post an additional function called gplvm_gradient_free that uses the NelderMead optimiser. This new function makes no calls to either DifferentiationInterface or Mooncake.

I then execute the code below. My observations concerning memory allocations (i.e. readings of %MEM column in top utility) are in the comments:

Y = randn(MersenneTwister(1234), 12, 1000) # create 1000 data items of 12 dimensions

X = gplvm(Y; iterations = 1) # warmup
X = gplvm(Y; iterations = 10) # I observe in top utility ~20% memory usage  (machine has 32GB memory)


# Run version that uses NelderMead and observe memory usage in top
X = gplvm_gradient_free(Y; iterations = 1)  # warmup
X = gplvm_gradient_free(Y; iterations = 10) # I observe in top utility ~20% memory usage (machine has 32GB memory)


# Call mooncake internals and GC 
empty!(Mooncake.get_interpreter().oc_cache) # uses internals -- not part of the public interface (yet)
GC.gc(true)

X = gplvm_gradient_free(Y; iterations = 10) # observe ~7% memory usage


# Call mooncake internals and GC but also check world counter
empty!(Mooncake.get_interpreter().oc_cache) # uses internals -- not part of the public interface (yet)
GC.gc(true)

Base.get_world_counter() # I get: 0x00000000000068e8
ggg(x) = sum(x) # define arbitrary function
Base.get_world_counter() # I get: 0x00000000000068e9

X = gplvm_gradient_free(Y; iterations = 10) # observe ~7% memory usage

I then contrast this with the following. I start a new Julia session and execute:

Y = randn(MersenneTwister(1234), 12, 1000)
X = gplvm_gradient_free(Y; iterations = 1)  # warmup
X = gplvm_gradient_free(Y; iterations = 10)  # I observe ~3% memory usage in top
X = gplvm_gradient_free(Y; iterations = 10)  # Run again, I observe ~3% memory usage in top

The memory consumption is lower in the second session, where no calls to either Mooncake or DifferentiationInterface are made.