Zygote.gradient does not work with AbstractGPs.CustomMean

Great.

In that case I would suggest implementing a new mean function, so that you can implement _map_meanfunction for it, and get the appropriate cotangent type when you work with ColVecs and RowVecs.

In particular, I would do something like this:

struct ParamModel{Tparams} <: AbstractGPs.MeanFunction
    params::Tparams
end

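function AbstractGPs._map_meanfunction(m::ParamModel, x::ColVecs)
    X = x.X  # D × N matrix: each column is one input
    # Compute the mean from `X` and `m.params` (broadcasting etc.) and
    # return a vector with one entry per input, i.e. of length N.
end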


function AbstractGPs._map_meanfunction(m::ParamModel, x::RowVecs)
    X = x.X  # N × D matrix: each row is one input
    # Compute the mean from `X` and `m.params` (broadcasting etc.) and
    # return a vector with one entry per input, i.e. of length N.
end

What’s going wrong with AD at the minute is that the ChainRulesCore.rrule for map doesn’t do the right thing for ColVecs and RowVecs, and that rrule is what _map_meanfunction hits for CustomMean. By implementing _map_meanfunction for your own type, you ensure that you never hit this code path.
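
To make that concrete, here is a minimal sketch of what a filled-in version might look like for a purely linear mean. The type name LinearMean and its fields w and b are hypothetical, chosen only for illustration; your own model will have different parameters and a different computation. The point is just that the mean is written directly in terms of the underlying matrix, so map is never called on the ColVecs / RowVecs wrapper.

```julia
using AbstractGPs
using KernelFunctions: ColVecs, RowVecs

# Hypothetical example type (not part of AbstractGPs): m(x) = w' * x + b
struct LinearMean{Tw<:AbstractVector,Tb<:Real} <: AbstractGPs.MeanFunction
    w::Tw
    b::Tb
end

# ColVecs wraps a D × N matrix (inputs are columns), so X' * w has length N.
AbstractGPs._map_meanfunction(m::LinearMean, x::ColVecs) = x.X' * m.w .+ m.b

# RowVecs wraps an N × D matrix (inputs are rows), so X * w has length N.
AbstractGPs._map_meanfunction(m::LinearMean, x::RowVecs) = x.X * m.w .+ m.b
```

You would then pass it to the GP in place of the CustomMean, e.g. GP(LinearMean(w, b), kernel), and differentiate with respect to w and b as usual.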

It’s an annoying work-around to have to make, but that’s (unfortunately) where we are with AD at the minute.
