Good practices regarding automatic differentiation in a Gaussian process implementation

(Related to Large memory consumption when using Mooncake via DifferentiationInterface for Gaussian process optimisation)

I’m starting this thread to discuss good practices for using Mooncake.jl via DifferentiationInterface.jl together with Optim.jl when implementing Gaussian processes (GPs). I use GPs in my research, but I often run into numerical issues. In particular, I work with the GPLVM model and often need to optimise thousands of parameters. Some of the issues I face may be unrelated to Mooncake.jl or DifferentiationInterface.jl and may instead stem from implementation details, e.g. type instability or other oversights. My goal is to document these problems and their resolutions in a way that others can reuse, by making them available in an online resource (perhaps here). To do this, I will share a sequence of code examples, where each iteration isolates and fixes a specific issue.

Although the Julia ecosystem offers many relevant packages, I want to keep the examples “self-contained” with as few dependencies as possible, so that the underlying issues are easy to identify.

Below I start with what I consider straightforward code, the kind a casual Julia user (including myself) might write. I begin with standard GP regression with one-dimensional inputs and outputs. I plan to address issues systematically, either within this thread or by opening a separate thread per issue (I’d appreciate your guidance on which approach is preferable).

The first issue that I want to address concerns a certain numerical problem that I have run into multiple times. Though the code may be inefficient in many respects, the focus is on isolating and hopefully resolving a very specific issue.


Issue 1: The code below implements standard GP regression for one-dimensional inputs and outputs using the RBF kernel. The function fitgp_1 implements two versions of the objective function to be maximised. The first, marginal_loglikelihood_MvNormal, uses Distributions.MvNormal to calculate the log-likelihood. The second, marginal_loglikelihood_explicit, calculates the log-likelihood explicitly via a Cholesky factorisation. Both versions return the same log-likelihood value for the same hyperparameters. However, when optimising using automatic gradients I get wildly different results: in particular, marginal_loglikelihood_explicit seems to suffer from numerical issues. It is not obvious why there is such a discrepancy between the two versions.

Why does the second version behave so differently?

using Distributions
using DifferentiationInterface
using LinearAlgebra
using Optim
using Random
import Mooncake

function toydata()
    # generate some synthetic data, not important how exactly,
    # just to have something to test the code on
    rng = MersenneTwister(1234)
    x = rand(rng, 1000)*10
    y = sin.(x) + 0.1*randn(rng, length(x))
    return y, x
end

function fitgp_1(t, x; iterations = 10, JITTER = 1e-6)

    # t are the target outputs (here one-dimensional)
    # x are the inputs (here one-dimensional)

    # Initialise hyperparameters
    logℓ, logα, logσ = 0.0, 0.0, 0.0
   
    # RBF kernel; note that exp(logℓ) here acts as an inverse squared lengthscale
    rbf(x, y, logℓ, logα) = exp(-abs2(x - y)*exp(logℓ))*exp(logα)

    function calculatecovariance!(K, logℓ, logα)
        # Can be done better, but the point here is to understand where
        # problems may arise with allocations and performance
        # when writing own code

        local N = length(x)

        for i in 1:N
            for j in 1:N
                K[i, j] = rbf(x[i], x[j], logℓ, logα)
            end
        end

    end

    function unpack_parameters(params)
        local logℓ, logα, logσ = params
        return logℓ, logα, logσ
    end

    function marginal_loglikelihood_MvNormal(logℓ, logα, logσ)

        # calculate the covariance matrix for the passed hyperparameters
        local K = zeros(length(x), length(x))
        calculatecovariance!(K, logℓ, logα)

        # add noise variance to the diagonal and jitter for numerical stability   
        K += exp(logσ) * I + JITTER * I 

        # return log marginal likelihood of a Gaussian process for Gaussian likelihood
        return logpdf(MvNormal(zeros(length(x)), K), t)

    end


    function marginal_loglikelihood_explicit(logℓ, logα, logσ)

        # calculate the covariance matrix for the passed hyperparameters
        local K = zeros(length(x), length(x))
        calculatecovariance!(K, logℓ, logα)

        # add noise variance to the diagonal and jitter for numerical stability   
        K += exp(logσ) * I + JITTER * I 

        # return log marginal likelihood of a Gaussian process for Gaussian likelihood
        local L = cholesky(Symmetric(K)).L
        return -0.5 * sum(abs2, L \ t) - sum(log, diag(L)) - 0.5*length(x)*log(2π)
    end

    # define negative log-likelihood for optimization
    nll_MvNormal(params) = -marginal_loglikelihood_MvNormal(unpack_parameters(params)...)
    nll_explicit(params) = -marginal_loglikelihood_explicit(unpack_parameters(params)...)

    # Sanity check:
    # show that the two implementations of the marginal log-likelihood 
    # differ only very slightly for the same hyperparameters.
    let
        initparam = randn(3)
        @show nll_MvNormal(initparam)
        @show nll_explicit(initparam)
    end

    # Optimise hyperparameters using LBFGS
    opt = Optim.Options(iterations = iterations, show_trace = true, show_every = 1)

    # optimise using the MvNormal implementation
    @info "optimise using the MvNormal implementation"
    @show optimize(nll_MvNormal, [logℓ, logα, logσ], LBFGS(), opt, autodiff=Mooncake.AutoMooncake())

    # optimise using the explicit implementation
    @info "optimise using the explicit implementation"
    @show optimize(nll_explicit, [logℓ, logα, logσ], LBFGS(), opt, autodiff=Mooncake.AutoMooncake())
    
end
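As a side check (not part of the listing above), the Cholesky-based expression used in marginal_loglikelihood_explicit can be verified against the direct Gaussian log-density on a tiny hand-made matrix, using only LinearAlgebra:

```julia
using LinearAlgebra

# Tiny 2x2 example; values are arbitrary, K just needs to be positive definite.
K = [2.0 0.5; 0.5 1.0]
t = [0.3, -0.7]
N = length(t)

# Cholesky-based form, as in marginal_loglikelihood_explicit:
# sum(abs2, L \ t) == tᵀK⁻¹t and sum(log, diag(L)) == 0.5 * logdet(K)
L = cholesky(Symmetric(K)).L
ll_chol = -0.5 * sum(abs2, L \ t) - sum(log, diag(L)) - 0.5 * N * log(2π)

# Direct form: -0.5 tᵀK⁻¹t - 0.5 logdet(K) - 0.5 N log(2π)
ll_direct = -0.5 * dot(t, K \ t) - 0.5 * logdet(K) - 0.5 * N * log(2π)

@assert ll_chol ≈ ll_direct
```

So the two primal computations agree, which matches the sanity check in fitgp_1; the question is only about their gradients.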

Provided the required packages are installed locally, you can copy-paste the above code and run it with:

y,x = toydata()
fitgp_1(y,x)

What you should see when running the above code is that the second optimisation run, which uses marginal_loglikelihood_explicit, runs into numerical issues and breaks prematurely.
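Not a diagnosis of this particular failure, but relevant background on why the JITTER term is in the listing at all: RBF kernel matrices are easily numerically rank-deficient, in which case the Cholesky factorisation fails outright (or its pullback blows up). A minimal stdlib-only illustration:

```julia
using LinearAlgebra

# A rank-1 positive semi-definite matrix: Cholesky fails without regularisation.
K = ones(5, 5)
@assert !issuccess(cholesky(Symmetric(K); check = false))

# Adding a small jitter to the diagonal restores positive definiteness.
JITTER = 1e-6
@assert issuccess(cholesky(Symmetric(K + JITTER * I); check = false))
```

Whether the jitter of 1e-6 is large enough at the parameter values the optimiser visits is a separate question.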

My local setup is Julia Version 1.12.4:

(MooncakeGPExperiments) pkg> st
Status `~/tmp/MooncakeGPExperiments/Project.toml`
  [a0c0ee7d] DifferentiationInterface v0.7.16
  [31c24e10] Distributions v0.25.123
  [da2b9cff] Mooncake v0.5.12
  [429524aa] Optim v2.0.1
  [37e2e46d] LinearAlgebra v1.12.0
  [9a3f8284] Random v1.11.0

Regarding “running into numerical issues and breaking prematurely”, I’m getting the following output:

julia> includet("test-gp.jl")
nll_MvNormal(initparam) = 1256.6934078230022
nll_explicit(initparam) = 1256.6934078230022
[ Info: optimise using the MvNormal implementation
Iter     Function value   Gradient norm 
     0     9.535819e+02     4.870311e+02
 * time: 0.008840084075927734
     1    -7.815952e+02     6.417509e+01
 * time: 1.4313230514526367
     2    -7.871668e+02     2.077147e+01
 * time: 1.724761962890625
     3    -7.906472e+02     4.496643e+01
 * time: 2.114284038543701
     4    -8.066678e+02     4.151909e+01
 * time: 2.6622111797332764
     5    -8.095018e+02     1.087353e+01
 * time: 2.8410961627960205
     6    -8.098022e+02     5.253655e+00
 * time: 3.1320290565490723
     7    -8.099309e+02     7.471592e+00
 * time: 3.526421070098877
     8    -8.108028e+02     4.509602e+00
 * time: 3.854161024093628
     9    -8.111302e+02     4.800723e-01
 * time: 4.141431093215942
    10    -8.111457e+02     2.406922e-01
 * time: 4.5342631340026855
optimize(nll_MvNormal, [logℓ, logα, logσ], LBFGS(), opt, autodiff = Mooncake.AutoMooncake()) =  * Status: failure (reached maximum number of iterations)

 * Candidate solution
    Final objective value:     -8.111457e+02

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 1.12e-01 ≰ 0.0e+00
    |x - x'|/|x'|          = 2.47e-02 ≰ 0.0e+00
    |f(x) - f(x')|         = 1.55e-02 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.91e-05 ≰ 0.0e+00
    |g(x)|                 = 2.41e-01 ≰ 1.0e-08

 * Work counters
    Seconds run:   5  (vs limit Inf)
    Iterations:    10
    f(x) calls:    35
    ∇f(x) calls:   35
    ∇f(x)ᵀv calls: 0

[ Info: optimise using the explicit implementation
Iter     Function value   Gradient norm 
     0     9.535819e+02     4.870311e+02
 * time: 9.608268737792969e-5
     1     7.082533e+02     1.087999e+07
 * time: 0.30930209159851074
     2     7.082540e+02     1.065883e+29
 * time: 6.451901912689209
┌ Warning: Failed to achieve finite new evaluation point, using alpha=0
└ @ LineSearches ~/.julia/packages/LineSearches/lihz0/src/hagerzhang.jl:156
     3     7.082540e+02              NaN
 * time: 11.357081890106201
┌ Warning: Terminated early due to NaN in gradient.
└ @ Optim ~/.julia/packages/Optim/DtV5C/src/multivariate/optimize/optimize.jl:136
optimize(nll_explicit, [logℓ, logα, logσ], LBFGS(), opt, autodiff = Mooncake.AutoMooncake()) =  * Status: failure

 * Candidate solution
    Final objective value:     7.082540e+02

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 0.00e+00 ≤ 0.0e+00
    |x - x'|/|x'|          = 0.00e+00 ≤ 0.0e+00
    |f(x) - f(x')|         = 0.00e+00 ≤ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 0.00e+00 ≤ 0.0e+00
    |g(x)|                 = NaN ≰ 1.0e-08

 * Work counters
    Seconds run:   11  (vs limit Inf)
    Iterations:    3
    f(x) calls:    113
    ∇f(x) calls:   113
    ∇f(x)ᵀv calls: 0

So indeed, the second implementation (nll_explicit) is experiencing numerical issues.


This is exactly what I observe: the same numbers during the optimiser iterations, and the premature break at exactly the same iteration. Thanks for confirming.

I found some parameter values that produce very different gradients.

# Note: this snippet is meant to run inside fitgp_1, where the two
# marginal log-likelihood closures are in scope; it also assumes
# `using Test, Printf, DifferentiationInterface` at the top level.
nll_MvNormal(params::AbstractVector) = -marginal_loglikelihood_MvNormal(params...)
nll_explicit(params::AbstractVector) = -marginal_loglikelihood_explicit(params...)

grad_MvNormal(params::AbstractVector) = gradient(nll_MvNormal, Mooncake.AutoMooncake(), params)
grad_explicit(params::AbstractVector) = gradient(nll_explicit, Mooncake.AutoMooncake(), params)

@testset "nlls equal" begin
    for _ in 1:200
        initparam = 8randn(3)
        @test nll_MvNormal(initparam) ≈ nll_explicit(initparam)
    end
end

@testset "nll gradients equal" begin
    for i in 1:200
        initparam = 8randn(3)
        tbeg = time()
        the_grad_MvNormal = grad_MvNormal(initparam)
        time_MvNormal = time() - tbeg

        tbeg = time()
        the_grad_explicit = grad_explicit(initparam)
        time_explicit = time() - tbeg

        @printf("%d %s\tMvNormal:%.4f explicit:%.4f\n", i, string(initparam), time_MvNormal, time_explicit)
        @test the_grad_MvNormal ≈ the_grad_explicit
    end
end

Output:

julia> fitgp_1(y, x)
Test Summary: | Pass  Total  Time
nlls equal    |  200    200  6.4s
1 [-14.199736251937544, -9.10561555706572, 2.6538254797721703]	MvNormal:1.8051 explicit:1.5531
2 [5.609290053467907, -0.3885215223335294, 1.3378757003073125]	MvNormal:0.5607 explicit:0.5135
3 [-4.4435922197548035, 0.5150626168870941, 0.5421203043287387]	MvNormal:0.5207 explicit:0.3041
4 [5.731650553337068, -6.735325591925515, 0.3741385199880161]	MvNormal:0.4634 explicit:0.5015
5 [2.535604988715315, 3.54025643504518, -7.261611030713925]	MvNormal:0.5241 explicit:0.3096
6 [2.362412579435235, 4.949111093754604, 5.847255214438686]	MvNormal:0.4619 explicit:0.5019
7 [-3.3609508312745775, -7.704552624782568, 0.7951606709014786]	MvNormal:0.2993 explicit:0.4610
8 [3.0501851494084233, 14.735943518357502, -19.306326079864107]	MvNormal:0.5147 explicit:0.3082
nll gradients equal: Test Failed at /Users/forcebru/test/test-gp.jl:81
  Expression: the_grad_MvNormal ≈ the_grad_explicit
   Evaluated: [-977214.8128871938, -25446.0, -18059.25871024232] ≈ [-971367.1885299702, -28378.0, -18059.258710238602]
...
10 [-0.08113466534285438, 15.506494081130453, -8.708303885960373]	MvNormal:0.2992 explicit:0.4706
nll gradients equal: Test Failed at /Users/forcebru/test/test-gp.jl:81
  Expression: the_grad_MvNormal ≈ the_grad_explicit
   Evaluated: [-373.8878286961408, -2.610595703125, -30560.1728372343] ≈ [-373.6003559526871, -5.477783203125, -30560.17283722648]
...
21 [-1.8824827558699302, 12.481122633478888, -5.412806554614432]	MvNormal:0.4609 explicit:0.5116
nll gradients equal: Test Failed at /Users/forcebru/test/test-gp.jl:81
  Expression: the_grad_MvNormal ≈ the_grad_explicit
   Evaluated: [18.50852768868208, 6.390734702348709, -697.4361307374119] ≈ [18.508415449410677, 6.389977499842644, -697.4361307374121]
...
24 [-2.8114261619152034, 8.802134156724248, -9.63474023524947]	MvNormal:0.3042 explicit:0.4650
nll gradients equal: Test Failed at /Users/forcebru/test/test-gp.jl:81
  Expression: the_grad_MvNormal ≈ the_grad_explicit
   Evaluated: [-316.85580825805664, -20.739328384399414, -78753.90638871562] ≈ [-316.88093090057373, -20.64815902709961, -78753.90638871577]
...
123 [-12.490092744128612, 23.9155285615708, -4.605820340345509]	MvNormal:0.3069 explicit:0.4683
nll gradients equal: Test Failed at /Users/forcebru/test/test-gp.jl:81
  Expression: the_grad_MvNormal ≈ the_grad_explicit
   Evaluated: [-184.94972801208496, -857.453125, -11823.697753225504] ≈ [-184.9943962097168, -1208.28125, -11823.697753225508]
...
128 [-3.467596657790478, 9.362793887890204, -14.569677208674502]	MvNormal:0.3088 explicit:0.4747
nll gradients equal: Test Failed at /Users/forcebru/test/test-gp.jl:81
  Expression: the_grad_MvNormal ≈ the_grad_explicit
   Evaluated: [-13023.033203125, -1254.0859375, -1.164290547821325e6] ≈ [-12988.89453125, -815.5, -1.1642905478213246e6]
...
170 [-3.8382965672555174, 21.092804035316234, -6.730961097268818]	MvNormal:0.5143 explicit:0.3081
nll gradients equal: Test Failed at /Users/forcebru/test/test-gp.jl:81
  Expression: the_grad_MvNormal ≈ the_grad_explicit
   Evaluated: [27.40576171875, -14.9638671875, -3988.72208146421] ≈ [9.25537109375, 18.4873046875, -3988.722081464212]
...

In some cases (mostly when the second element of the parameter vector, logα, is large and positive and the third, logσ, is large and negative), the derivative with respect to the second parameter differs a lot between the two implementations and can even change sign.
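When two AD gradients disagree like this, an AD-free central finite-difference check can arbitrate which one (if either) is right at the offending parameter values. The helper below, fdgrad, is my own stdlib-only sketch, not something from the code above:

```julia
# Hypothetical helper: central-difference gradient, for cross-checking AD results.
# h trades off truncation error against floating-point cancellation.
function fdgrad(f, x::AbstractVector{<:AbstractFloat}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

# Quick self-check on a function with a known gradient (∇ sum(x.^2) = 2x):
g = fdgrad(x -> sum(abs2, x), [1.0, -2.0, 3.0])
@assert isapprox(g, [2.0, -4.0, 6.0]; atol = 1e-6)
```

One would then compare fdgrad(nll_explicit, initparam) and fdgrad(nll_MvNormal, initparam) against the two Mooncake gradients at the failing initparam values reported above, bearing in mind that finite differences themselves become unreliable when the objective is badly conditioned.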
