I just tried out the following code, based on an example from the AbstractGPs.jl repo. It can be run by copy-paste, provided the necessary packages are installed:
using AbstractGPs, DifferentiationInterface, Optim, StatsFuns, Random
import Mooncake
import Zygote
let
    rng = MersenneTwister(1)
    x = randn(rng, 10_000)
    y = randn(rng, 10_000)
    f = GP(Matern52Kernel())
    noise_var = 0.1
    fx = f(x, noise_var)

    # Build a closure over the data; `params` holds the unconstrained kernel parameters,
    # mapped to positive values via softplus.
    function loss_function(x, y)
        function negativelogmarginallikelihood(params)
            kernel =
                softplus(params[1]) * (Matern52Kernel() ∘ ScaleTransform(softplus(params[2])))
            f = GP(kernel)
            fx = f(x, noise_var)
            return -logpdf(fx, y)
        end
        return negativelogmarginallikelihood
    end

    θ0 = randn(rng, 2)

    # Mooncake - runs out of memory
    opt = Optim.optimize(loss_function(x[1:4], y[1:4]), θ0, LBFGS(), autodiff=AutoMooncake(config=nothing)) # warmup
    opt = Optim.optimize(loss_function(x, y), θ0, LBFGS(), autodiff=AutoMooncake(config=nothing))

    # comment in to use Zygote instead - also runs out of memory
    #opt = Optim.optimize(loss_function(x[1:4], y[1:4]), θ0, LBFGS(), autodiff=AutoZygote()) # warmup
    #opt = Optim.optimize(loss_function(x, y), θ0, LBFGS(), autodiff=AutoZygote())
end
With the 10_000 data points above, my 32 GB machine runs out of memory and the Julia process is killed.
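For reference, here is a rough back-of-the-envelope estimate (my own, not part of the original example) of the dense covariance matrix that an exact GP materializes at this size; reverse-mode AD will typically keep several buffers of this size alive at once, which I assume is where the memory goes:

# Size of a single dense n×n Float64 covariance matrix (assumes an exact GP, no approximation).
n = 10_000
bytes_per_float = 8
cov_matrix_gb = n^2 * bytes_per_float / 1e9
println("One dense covariance matrix: ", cov_matrix_gb, " GB")  # ≈ 0.8 GB per copy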