Differentiation (Zygote, but the issue is likely in ChainRules) returns different types with `Diagonal`

Dear All,

We are hacking a bit on AbstractGPs.jl and have found an inconsistency in the type of the gradient returned for a `Diagonal`, which causes Zygote to complain.

An MWE adapted from the GP code is:

```julia
using LinearAlgebra, Zygote

K = [1.0009999993448562 0.24527300257806056 0.3032821627747173 0.6914136223631672 0.3746318512475686 0.25806017362715294 0.2078077842767286 0.4496418717665137 0.3664088915838278 0.24069668882897371; 0.24527300257806056 1.0009999993448562 0.2801290863077014 0.20891979354686166 0.26106848570002494 0.5040269983590268 0.3574840458389145 0.3297751063879473 0.25069027539785393 0.5163648478175171; 0.3032821627747173 0.2801290863077014 1.0009999993448562 0.5269272168279036 0.31515192873951936 0.244109237018002 0.1916241919637497 0.4921661376941947 0.3126054947854366 0.2595934459178986; 0.6914136223631672 0.20891979354686166 0.5269272168279036 1.0009999993448562 0.29792724577200785 0.22003770001639386 0.17936870117256806 0.3644702409595687 0.30401210750076535 0.20455516195564313; 0.3746318512475686 0.26106848570002494 0.31515192873951936 0.29792724577200785 1.0009999993448562 0.2827677397937746 0.23783913557120728 0.5411063090341129 0.41259587506668016 0.2703822271373879; 0.25806017362715294 0.5040269983590268 0.244109237018002 0.22003770001639386 0.2827677397937746 1.0009999993448562 0.489697495653453 0.32013626566362274 0.2672905224451533 0.44659724020098884; 0.2078077842767286 0.3574840458389145 0.1916241919637497 0.17936870117256806 0.23783913557120728 0.489697495653453 1.0009999993448562 0.30022124481837253 0.22292563762388867 0.4061051235451615; 0.4496418717665137 0.3297751063879473 0.4921661376941947 0.3644702409595687 0.5411063090341129 0.32013626566362274 0.30022124481837253 1.0009999993448562 0.4921756989641169 0.3529789127923668; 0.3664088915838278 0.25069027539785393 0.3126054947854366 0.30401210750076535 0.41259587506668016 0.2672905224451533 0.22292563762388867 0.4921756989641169 1.0009999993448562 0.3083105416251713; 0.24069668882897371 0.5163648478175171 0.2595934459178986 0.20455516195564313 0.2703822271373879 0.44659724020098884 0.4061051235451615 0.3529789127923668 0.3083105416251713 1.0009999993448562]
d_ll = [-0.4789038614343079, 0.3526011634298183, -0.4788187573129515, 0.4800358657914971, -0.44917358411354336, 0.34923344477230933, 0.35035293763581776, 0.45598676261552384, -0.4459245159060682, 0.3551514395341081]
d2_ll = [-0.24955495293761715, -0.22827358297775682, -0.24955135495823239, -0.24960143334530488, -0.24741667544813697, -0.2272694458247757, -0.22760575672577055, -0.24806283493493775, -0.24707584202000693, -0.22901889453095886]
f = [-0.08443468127896654, 0.6076249748777804, -0.08477570730728931, 0.0798990150113338, -0.20401031020993757, 0.6224103889912268, 0.6174881948022232, 0.1765098002461307, -0.21715123949401738, 0.5964712105043362]

gradient(d2_ll) do d2_ll
    W = -Diagonal(d2_ll)
    Wsqrt = sqrt(W)
    B = I + Wsqrt * K * Wsqrt
    B_ch = cholesky(Symmetric(B))
    b = W * f + d_ll
    a = b - Wsqrt * (B_ch \ (Wsqrt * K * b))
    sum(a)
end
```

which fails with `ERROR: MethodError: no method matching +(::Diagonal{Float64, Vector{Float64}}, ::@NamedTuple{diag::Vector{Float64}})`.

I think the culprit is an inconsistency in the return types of the following two gradients:

```julia
W = -Diagonal(d2_ll)
Wsqrt = sqrt(W)
B = I + Wsqrt * K * Wsqrt
B_ch = cholesky(Symmetric(B))
b = W * f + d_ll

gradient(Wsqrt -> sum(Wsqrt * (B_ch \ (Wsqrt * K * b))), Wsqrt)[1]
gradient(Wsqrt -> sum(Wsqrt * K * Wsqrt), Wsqrt)[1]
```

The first returns a `Diagonal`, while the second returns a `NamedTuple` with a single `diag` field (`@NamedTuple{diag::Vector{Float64}}`). I am not sure which one is correct; I can imagine either.
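To make the two representations concrete, here is a small sketch using only LinearAlgebra. The "natural" tangent has the same type as the primal (`Diagonal`), while the "structural" one has one entry per field of the primal; `Diagonal` has a single field, `diag`, which matches the `NamedTuple` in the error message. The helper `to_natural` is hypothetical and only illustrates that the two forms carry the same information, but cannot be added until one is converted.

```julia
using LinearAlgebra

# Natural tangent: same type as the primal W.
dW_natural = Diagonal([1.0, 2.0])

# Structural tangent: one entry per field of Diagonal (its only field is `diag`).
dW_structural = (diag = [3.0, 4.0],)

# Hypothetical helper converting either form to the natural one.
to_natural(dx::Diagonal) = dx
to_natural(dx::NamedTuple{(:diag,)}) = Diagonal(dx.diag)

# Adding the two forms directly throws a MethodError, as in the MWE.
mixed_sum_fails = try
    dW_natural + dW_structural
    false
catch err
    err isa MethodError
end

# After conversion, accumulation works and stays Diagonal.
dW_total = to_natural(dW_natural) + to_natural(dW_structural)
```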

A temporary monkeypatch is to define

```julia
import Base.+
+(a::Diagonal, b::@NamedTuple{diag::Vector{Float64}}) = Diagonal(a.diag + b.diag)
```

but this does not fix the underlying problem in ChainRules.

This is a mismatch between the natural and structural gradient representations. ChainRules’s projection mechanism should standardise on the natural one, like this:

```julia
julia> using ChainRulesCore

julia> const ArrayOrZero = Union{AbstractArray, AbstractZero};

julia> function (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal}) # structural => natural
            return dx.diag isa ArrayOrZero ? Diagonal(project.diag(dx.diag)) : dx
        end

julia> gradient(Wsqrt -> sum(Wsqrt * K * Wsqrt), Wsqrt)[1]
10×10 Diagonal{Float64, Vector{Float64}}:
 4.08791   ⋅        ⋅       ⋅        ⋅        ⋅        ⋅        ⋅        ⋅        ⋅ 
  ⋅       3.84352   ⋅       ⋅        ⋅        ⋅        ⋅        ⋅        ⋅        ⋅ 
 ⋮
```
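For comparison, `ProjectTo` already maps a dense cotangent onto the natural `Diagonal` form; the method above extends the same idea to the structural `Tangent{<:Diagonal}` case. A minimal sketch, assuming a ChainRulesCore version that provides `ProjectTo` (the names `Wd`, `project` and `dense` are illustrative only):

```julia
using ChainRulesCore, LinearAlgebra

Wd = Diagonal([1.0, 2.0])
project = ProjectTo(Wd)  # records that the primal is a Diagonal of Float64

# A dense cotangent, e.g. from a rule that ignores the structure of Wd.
dense = [1.0 2.0; 3.0 4.0]

# Projection keeps only the diagonal and returns the natural representation.
dW = project(dense)

# A cotangent that is already a Diagonal passes through unchanged.
dW_diag = project(Diagonal([5.0, 6.0]))
```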

This makes sense, and it is in the area I was expecting. I would like to find the culprit and fix it for others. So my question is: is this a wrong `rrule` that is missing the `ProjectTo` mechanism, or the missing projection method that @mcabbott wrote above? I think it makes sense to open a pull request fixing this.
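For reference, the "`rrule` with `ProjectTo`" option looks roughly like the sketch below. The function `sandwich` is hypothetical: it stands in for `Wsqrt * K * Wsqrt` and is not necessarily the rule involved in this bug. The pullback projects each cotangent onto the type of its primal, so a `Diagonal` input receives a `Diagonal` gradient rather than a dense matrix.

```julia
using ChainRulesCore, LinearAlgebra

# Hypothetical function standing in for `Wsqrt * K * Wsqrt` in the MWE.
sandwich(W, K) = W * K * W

function ChainRulesCore.rrule(::typeof(sandwich), W, K)
    project_W = ProjectTo(W)  # e.g. remembers that W is a Diagonal
    project_K = ProjectTo(K)
    Y = W * K * W
    function sandwich_pullback(Ybar_thunk)
        Ybar = unthunk(Ybar_thunk)
        # dY = dW*K*W + W*K*dW, and K enters as W*dK*W.
        Wbar = Ybar * (K * W)' + (W * K)' * Ybar
        Kbar = W' * Ybar * W'
        # Project onto the primal types: the dense Wbar becomes a Diagonal.
        return NoTangent(), project_W(Wbar), project_K(Kbar)
    end
    return Y, sandwich_pullback
end

Wd = Diagonal([1.0, 2.0])
Kd = [1.0 0.5; 0.5 1.0]
Y, pullback = rrule(sandwich, Wd, Kd)
_, Wbar, Kbar = pullback(ones(2, 2))
```

With a symmetric `K`, the gradient of `sum(W * K * W)` with respect to the diagonal `w` is `2 * K * w`, which here gives `[4.0, 5.0]` and matches `Wbar`.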