Great.
In that case, I would suggest defining a new mean function type, so that you can implement `_map_meanfunction` for it yourself and get the appropriate cotangent type when you work with `ColVecs` and `RowVecs`.
In particular, I would do something like:
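```julia
using AbstractGPs  # also brings ColVecs and RowVecs (from KernelFunctions) into scope

struct ParamModel{Tparams} <: AbstractGPs.MeanFunction
    params::Tparams
end

# ColVecs: each column of `X` is one input, so `X` is D × N.
function AbstractGPs._map_meanfunction(m::ParamModel, x::ColVecs)
    X = x.X
    # whatever computation needs to happen on `X` using broadcasting etc.;
    # should return a vector of length `size(X, 2)`
end

# RowVecs: each row of `X` is one input, so `X` is N × D.
function AbstractGPs._map_meanfunction(m::ParamModel, x::RowVecs)
    X = x.X
    # whatever computation needs to happen on `X` using broadcasting etc.;
    # should return a vector of length `size(X, 1)`
end
```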
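For concreteness, here is a minimal sketch of that skeleton filled in, assuming a hypothetical linear mean `m(x) = wᵀx` (`LinearMean` and `w` are illustrative names, not part of AbstractGPs) and that Zygote is installed. It checks that the `ColVecs` and `RowVecs` methods agree on the same inputs, and that Zygote can differentiate the mean vector with respect to `w`.

```julia
using AbstractGPs, Zygote

# Hypothetical linear mean m(x) = w' * x, with one weight per input dimension.
struct LinearMean{Tw<:AbstractVector{<:Real}} <: AbstractGPs.MeanFunction
    w::Tw
end

# ColVecs: X is D × N (columns are inputs), so X' * w has length N.
AbstractGPs._map_meanfunction(m::LinearMean, x::ColVecs) = x.X' * m.w

# RowVecs: X is N × D (rows are inputs), so X * w has length N.
AbstractGPs._map_meanfunction(m::LinearMean, x::RowVecs) = x.X * m.w

D, N = 3, 5
X = randn(D, N)
w = randn(D)

# The same inputs in both layouts should give the same mean vector.
m_col = AbstractGPs._map_meanfunction(LinearMean(w), ColVecs(X))
m_row = AbstractGPs._map_meanfunction(LinearMean(w), RowVecs(permutedims(X)))

# Gradient of a scalar loss w.r.t. the mean's parameters. Zygote only has to
# differentiate a matrix-vector product here, not a `map` over the inputs.
g = Zygote.gradient(w -> sum(AbstractGPs._map_meanfunction(LinearMean(w), ColVecs(X))), w)[1]
```

Since each method operates on the whole matrix `x.X` at once, the AD path in this sketch never goes through `map` over individual `ColVecs`/`RowVecs` elements.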
What's going wrong with AD at the minute is that the `ChainRulesCore.rrule` for `map` doesn't produce the right cotangent for `ColVecs` and `RowVecs`, and that is exactly what `_map_meanfunction` hits for `CustomMean`. By implementing `_map_meanfunction` on your own type, you can ensure that you never hit this code path.
It’s an annoying work-around to have to make, but that’s (unfortunately) where we are with AD at the minute.