Zygote.gradient does not work with AbstractGPs.CustomMean

Hello, I am trying to use an AbstractGPs.FiniteGP with Turing, and I am having problems computing gradients with respect to the GP hyperparameters using Zygote.gradient.

For some reason, Zygote fails to calculate the gradient when using an AbstractGPs.CustomMean but works properly when using AbstractGPs.ZeroMean.

When I try using ForwardDiff instead, both gradients are calculated without an issue.

MWE of the issue:

using AbstractGPs
using Zygote
using ForwardDiff

X = [1.; 2.; 3.;;]
y = [1., 2., 3.]

function construct_finite_gp(X, lengthscale, noise; min_param_val=1e-6, mean)
    # for numerical stability
    lengthscale = lengthscale + min_param_val
    noise = noise + min_param_val

    kernel = with_lengthscale(Matern52Kernel(), lengthscale)
    return GP(mean, kernel)(X', noise)
end

function f(noise; mean)
    gp = construct_finite_gp(X, 1., first(noise); mean)
    return logpdf(gp, y)
end

forwarddiff_zeromean = ForwardDiff.gradient(n -> f(n; mean=AbstractGPs.ZeroMean()), [1.])
@show forwarddiff_zeromean

forwarddiff_custommean = ForwardDiff.gradient(n -> f(n; mean=AbstractGPs.CustomMean(x->0.)), [1.])
@show forwarddiff_custommean

zygote_zeromean = Zygote.gradient(n -> f(n; mean=AbstractGPs.ZeroMean()), 1.)
@show zygote_zeromean

zygote_custommean = Zygote.gradient(n -> f(n; mean=AbstractGPs.CustomMean(x->0.)), 1.)
@show zygote_custommean

The result of running the MWE:

forwarddiff_zeromean = [0.2630913490608584]
forwarddiff_custommean = [0.2630913490608584]
zygote_zeromean = (0.26309134906085807,)
ERROR: MethodError: no method matching +(::NamedTuple{(:X,), Tuple{Matrix{Float64}}}, ::Vector{Nothing})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  +(::Distributions.MvNormal, ::AbstractVector) at C:\Users\sheld\.julia\packages\Distributions\7iOJp\src\multivariate\mvnormal.jl:294
  +(::SparseArrays.AbstractSparseMatrixCSC, ::Array) at C:\Users\sheld\AppData\Local\Programs\Julia-1.8.0\share\julia\stdlib\v1.8\SparseArrays\src\sparsematrix.jl:1832
  ...
Stacktrace:
  [1] accum(x::NamedTuple{(:X,), Tuple{Matrix{Float64}}}, y::Vector{Nothing}) (repeats 2 times)
    @ Zygote C:\Users\sheld\.julia\packages\Zygote\dABKa\src\lib\lib.jl:17
  [2] (::typeof(∂(mean_and_cov)))(Δ::Tuple{Vector{Float64}, LinearAlgebra.UpperTriangular{Float64, Matrix{Float64}}})
    @ Zygote C:\Users\sheld\.julia\packages\Zygote\dABKa\src\compiler\interface2.jl:0
  [3] Pullback
    @ C:\Users\sheld\.julia\packages\AbstractGPs\e2ey8\src\finite_gp_projection.jl:134 [inlined]
  [4] (::typeof(∂(mean_and_cov)))(Δ::Tuple{Vector{Float64}, LinearAlgebra.UpperTriangular{Float64, Matrix{Float64}}})
    @ Zygote C:\Users\sheld\.julia\packages\Zygote\dABKa\src\compiler\interface2.jl:0
  [5] Pullback
    @ C:\Users\sheld\.julia\packages\AbstractGPs\e2ey8\src\finite_gp_projection.jl:307 [inlined]
  [6] (::typeof(∂(logpdf)))(Δ::Float64)
    @ Zygote C:\Users\sheld\.julia\packages\Zygote\dABKa\src\compiler\interface2.jl:0
  [7] Pullback
    @ d:\plzen\sandbox\sampling_error_mwe\zygote_mwe.jl:19 [inlined]
  [8] (::typeof(∂(#f#47)))(Δ::Float64)
    @ Zygote C:\Users\sheld\.julia\packages\Zygote\dABKa\src\compiler\interface2.jl:0
  [9] Pullback
    @ d:\plzen\sandbox\sampling_error_mwe\zygote_mwe.jl:17 [inlined]
 [10] (::typeof(∂(f##kw)))(Δ::Float64)
    @ Zygote C:\Users\sheld\.julia\packages\Zygote\dABKa\src\compiler\interface2.jl:0
 [11] Pullback
    @ d:\plzen\sandbox\sampling_error_mwe\zygote_mwe.jl:31 [inlined]
 [12] (::typeof(∂(#56)))(Δ::Float64)
    @ Zygote C:\Users\sheld\.julia\packages\Zygote\dABKa\src\compiler\interface2.jl:0
 [13] (::Zygote.var"#60#61"{typeof(∂(#56))})(Δ::Float64)
    @ Zygote C:\Users\sheld\.julia\packages\Zygote\dABKa\src\compiler\interface.jl:45
 [14] gradient(f::Function, args::Float64)
    @ Zygote C:\Users\sheld\.julia\packages\Zygote\dABKa\src\compiler\interface.jl:97
 [15] top-level scope
    @ d:\plzen\sandbox\sampling_error_mwe\zygote_mwe.jl:31

I need to use Zygote because of this issue.

Does anyone know what might be the cause of this?
Is it a bug in Zygote or AbstractGPs or some error on my side?

Hi @soldasim. This sounds like a frustrating issue.

Would I be correct in assuming that you’re not actually interested in using the custom mean function CustomMean(x -> 0.)? Would you mind providing the one that you’re actually interested in?

Hi @willtebbutt,

In my model, I am using a parametric model as the mean of the GP.
I want the code to be general: the parametric model is defined by the user, so it needs to work with any real-valued function.

An example could be:

function param_model(x, params)
    return params[1] * cos(params[2] * x[1]) + params[3]
end

zygote_custommean = Zygote.gradient(n -> f(n; mean=AbstractGPs.CustomMean(x -> param_model(x, [1.,2.,3.]))), 1.)

Okay, thanks for the additional info.

One more question: is the actual value of X that you plan to use going to be multi-dimensional, or do you only ever do 1D stuff?

X can be multidimensional.

Great.

In that case I would suggest implementing a new mean function, so that you can implement _map_meanfunction for it, and get the appropriate cotangent type when you work with ColVecs and RowVecs.
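For readers unfamiliar with these two input types: both wrap a matrix and present it as a vector of inputs, and they differ only in which dimension indexes the inputs. The small sketch below (the matrix values are made up for illustration) shows the convention; the wrapped matrix is stored in the field `X`, which is what a `_map_meanfunction` overload unpacks with `x.X`.

```julia
using AbstractGPs  # re-exports ColVecs and RowVecs from KernelFunctions.jl

# A 2×3 matrix used purely for illustration.
X = [1.0 2.0 3.0;
     4.0 5.0 6.0]

xc = ColVecs(X)  # each column is one input: 3 inputs of length 2
xr = RowVecs(X)  # each row is one input: 2 inputs of length 3

@show length(xc) length(xr)  # 3 and 2
@show xc[1]                  # first column, [1.0, 4.0]
@show xr[1]                  # first row, [1.0, 2.0, 3.0]
@show xc.X === X             # the wrapper keeps the original matrix in field `X`
```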

In particular, I would do something like

struct ParamModel{Tparams} <: AbstractGPs.MeanFunction
    params::Tparams
end

function AbstractGPs._map_meanfunction(m::ParamModel, x::ColVecs)
    X = x.X
    # whatever computation needs to happen on `X` using broadcasting etc
end


function AbstractGPs._map_meanfunction(m::ParamModel, x::RowVecs)
    X = x.X
    # whatever computation needs to happen on `X` using broadcasting etc
end
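As a concrete illustration, the comment bodies above could be filled in for the `param_model` example from earlier in the thread, which only uses the first coordinate of each input. This is a sketch under that assumption: `param_model_cols` and `param_model_rows` are hypothetical helper names, and they compute every mean value at once with broadcasting instead of calling `param_model` per input through `map`. The checks at the end compare them with the per-input `map` version.

```julia
# The per-input model from the thread; only x[1] is used.
param_model(x, params) = params[1] * cos(params[2] * x[1]) + params[3]

# Hypothetical vectorized bodies for the two overloads.
# ColVecs: inputs are columns, so the first coordinates are the first row.
param_model_cols(X::AbstractMatrix, params) =
    params[1] .* cos.(params[2] .* X[1, :]) .+ params[3]

# RowVecs: inputs are rows, so the first coordinates are the first column.
param_model_rows(X::AbstractMatrix, params) =
    params[1] .* cos.(params[2] .* X[:, 1]) .+ params[3]

params = [1.0, 2.0, 3.0]
Xc = [1.0 2.0 3.0;
      0.5 0.5 0.5]      # three 2-D inputs stored as columns
Xr = permutedims(Xc)    # the same three inputs stored as rows

# Both agree with mapping `param_model` over the individual inputs.
@show param_model_cols(Xc, params) ≈ map(c -> param_model(c, params), eachcol(Xc))
@show param_model_rows(Xr, params) ≈ map(r -> param_model(r, params), eachrow(Xr))
```

Inside the overloads, the calls would then be `param_model_cols(x.X, m.params)` and `param_model_rows(x.X, m.params)` respectively.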

What’s going wrong with AD at the minute is that the ChainRulesCore.rrule for map doesn’t do the right thing for ColVecs and RowVecs, which is the code path that _map_meanfunction hits for CustomMean. By implementing _map_meanfunction for your own type, you can ensure that you never hit this code path.

It’s an annoying work-around to have to make, but that’s (unfortunately) where we are with AD at the minute.

Thank you so much! It would have been difficult for me to find the cause, as I’m not that experienced in Julia yet.

I’ve generalized the solution into

struct MyCustomMean{Tf} <: AbstractGPs.MeanFunction
    f::Tf
end

function AbstractGPs._map_meanfunction(m::MyCustomMean, x::ColVecs)
    X = x.X
    map(m.f, eachcol(X))
end

function AbstractGPs._map_meanfunction(m::MyCustomMean, x::RowVecs)
    X = x.X
    map(m.f, eachrow(X))  # for RowVecs, each row of X is one input
end

so that MyCustomMean can be used exactly like AbstractGPs.CustomMean.
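The RowVecs method has to iterate over rows rather than columns, because with RowVecs each row of the matrix is one input. The base-Julia sketch below (the mean function `g` is a hypothetical stand-in) shows what goes wrong otherwise: mapping over columns yields one value per coordinate instead of one value per input.

```julia
g(x) = sum(x)  # hypothetical mean function of a single input vector

# Interpreted as RowVecs: 2 inputs, each of length 3.
X = [1.0 2.0 3.0;
     4.0 5.0 6.0]

per_input      = map(g, eachrow(X))  # one value per input: [6.0, 15.0]
per_coordinate = map(g, eachcol(X))  # wrong for RowVecs: [5.0, 7.0, 9.0]

@show per_input per_coordinate
```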

The code seems to work now:

julia> ForwardDiff.gradient(n -> f(n; mean=AbstractGPs.CustomMean(x->0.)), [1.])
1-element Vector{Float64}:
 0.2630913490608584

julia> ForwardDiff.gradient(n -> f(n; mean=MyCustomMean(x->0.)), [1.])
1-element Vector{Float64}:
 0.2630913490608584

julia> Zygote.gradient(n -> f(n; mean=MyCustomMean(x->0.)), 1.)
(0.26309134906085807,)

Thank you for your time and help! 🙂

Hi, could you please let me know what I am doing wrong here, or has something changed?

using AbstractGPs
using Zygote
# using ForwardDiff

X = [1.; 2.; 3.;;]
y = [1., 2., 3.]

function construct_finite_gp(X, lengthscale, noise; min_param_val=1e-6, mean)
    # for numerical stability
    lengthscale = lengthscale + min_param_val
    noise = noise + min_param_val

    kernel = with_lengthscale(Matern52Kernel(), lengthscale)
    return GP(mean, kernel)(X', noise)
end

function f(noise; mean)
    gp = construct_finite_gp(X, 1., first(noise); mean)
    return logpdf(gp, y)
end


struct MyCustomMean{Tf} <: AbstractGPs.MeanFunction
    f::Tf
end

function AbstractGPs._map_meanfunction(m::MyCustomMean, x::ColVecs)
    X = x.X
    map(m.f, eachcol(X))
end

function AbstractGPs._map_meanfunction(m::MyCustomMean, x::RowVecs)
    X = x.X
    map(m.f, eachcol(X))
end


# zygote_zeromean = Zygote.gradient(n -> f(n; mean=AbstractGPs.ZeroMean()), 1.)
# @show zygote_zeromean

Zygote.gradient(n -> f(n; mean=MyCustomMean(x->0.)), 1.)

I’m not sure what changed, but this is how I currently do it.

You don’t need the MyCustomMean struct; just pass the custom mean as a plain function, which AbstractGPs wraps in a CustomMean. ForwardDiff works as is, and for Zygote you need these two methods as a workaround (note that the function to extend is now mean_vector rather than _map_meanfunction):

AbstractGPs.mean_vector(m::AbstractGPs.CustomMean, x::ColVecs) = map(m.f, eachcol(x.X))
AbstractGPs.mean_vector(m::AbstractGPs.CustomMean, x::RowVecs) = map(m.f, eachrow(x.X))

The complete fixed script can look like this

using AbstractGPs
using Zygote
using ForwardDiff

X = [1.;; 2.;; 3.;;]
y = [1., 2., 3.]

function construct_finite_gp(X, lengthscale, noise; min_param_val=1e-6, mean)
    # for numerical stability
    lengthscale = lengthscale + min_param_val
    noise = noise + min_param_val

    kernel = with_lengthscale(Matern52Kernel(), lengthscale)
    return GP(mean, kernel)(X, noise)
end

function f(noise; mean)
    gp = construct_finite_gp(X, 1., first(noise); mean)
    return logpdf(gp, y)
end

# workaround needed for Zygote
AbstractGPs.mean_vector(m::AbstractGPs.CustomMean, x::ColVecs) = map(m.f, eachcol(x.X))
AbstractGPs.mean_vector(m::AbstractGPs.CustomMean, x::RowVecs) = map(m.f, eachrow(x.X))

zygote = Zygote.gradient(n -> f(n; mean=x->0.), 1.)
fwddiff = ForwardDiff.gradient(n -> f(n; mean=x->0.), [1.])

@show zygote
@show fwddiff

Thanks a lot @soldasim!
This worked for me.
Unfortunately, AbstractGPs doesn’t seem to keep up with Zygote very quickly. I think they need to vectorize the mean computation.
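A vectorized mean of that kind can be sketched as a user-defined mean type whose mean_vector methods work on the whole input matrix with a matrix product and broadcasting, instead of calling a function once per input through map. This is only a sketch: LinearMean and its fields w and b are hypothetical, it assumes an AbstractGPs version where the function to extend is mean_vector (as in the workaround above), and it has not been verified here against a specific Zygote version.

```julia
using AbstractGPs

# Hypothetical linear mean m(x) = w'x + b.
struct LinearMean{Tw<:AbstractVector,Tb<:Real} <: AbstractGPs.MeanFunction
    w::Tw
    b::Tb
end

# ColVecs: x.X is d×n with one input per column, so X' * w gives n values.
AbstractGPs.mean_vector(m::LinearMean, x::ColVecs) = x.X' * m.w .+ m.b

# RowVecs: x.X is n×d with one input per row, so X * w gives n values.
AbstractGPs.mean_vector(m::LinearMean, x::RowVecs) = x.X * m.w .+ m.b

m = LinearMean([1.0, 2.0], 0.5)
X = [1.0 2.0 3.0;
     4.0 5.0 6.0]  # three 2-D inputs stored as columns

mc = AbstractGPs.mean_vector(m, ColVecs(X))               # [9.5, 12.5, 15.5]
mr = AbstractGPs.mean_vector(m, RowVecs(permutedims(X)))  # same inputs stored as rows

# The mean type can then be used in a GP like any other mean function.
lp = logpdf(GP(m, SEKernel())(ColVecs(X), 0.1), [1.0, 2.0, 3.0])
@show mc mr lp
```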
