Dear All,
We have been hacking on AbstractGPs.jl a bit and found an inconsistency in the type of the gradient returned for a `Diagonal` matrix, which makes Zygote throw an error.
A minimal working example (MWE), adapted from our GP code, is:
using LinearAlgebra, Zygote
# Fixed 10×10 input matrix (presumably a GP kernel/covariance matrix — confirm).
# It is symmetric, and every diagonal entry is 1.0009999993448562.
K = [1.0009999993448562 0.24527300257806056 0.3032821627747173 0.6914136223631672 0.3746318512475686 0.25806017362715294 0.2078077842767286 0.4496418717665137 0.3664088915838278 0.24069668882897371; 0.24527300257806056 1.0009999993448562 0.2801290863077014 0.20891979354686166 0.26106848570002494 0.5040269983590268 0.3574840458389145 0.3297751063879473 0.25069027539785393 0.5163648478175171; 0.3032821627747173 0.2801290863077014 1.0009999993448562 0.5269272168279036 0.31515192873951936 0.244109237018002 0.1916241919637497 0.4921661376941947 0.3126054947854366 0.2595934459178986; 0.6914136223631672 0.20891979354686166 0.5269272168279036 1.0009999993448562 0.29792724577200785 0.22003770001639386 0.17936870117256806 0.3644702409595687 0.30401210750076535 0.20455516195564313; 0.3746318512475686 0.26106848570002494 0.31515192873951936 0.29792724577200785 1.0009999993448562 0.2827677397937746 0.23783913557120728 0.5411063090341129 0.41259587506668016 0.2703822271373879; 0.25806017362715294 0.5040269983590268 0.244109237018002 0.22003770001639386 0.2827677397937746 1.0009999993448562 0.489697495653453 0.32013626566362274 0.2672905224451533 0.44659724020098884; 0.2078077842767286 0.3574840458389145 0.1916241919637497 0.17936870117256806 0.23783913557120728 0.489697495653453 1.0009999993448562 0.30022124481837253 0.22292563762388867 0.4061051235451615; 0.4496418717665137 0.3297751063879473 0.4921661376941947 0.3644702409595687 0.5411063090341129 0.32013626566362274 0.30022124481837253 1.0009999993448562 0.4921756989641169 0.3529789127923668; 0.3664088915838278 0.25069027539785393 0.3126054947854366 0.30401210750076535 0.41259587506668016 0.2672905224451533 0.22292563762388867 0.4921756989641169 1.0009999993448562 0.3083105416251713; 0.24069668882897371 0.5163648478175171 0.2595934459178986 0.20455516195564313 0.2703822271373879 0.44659724020098884 0.4061051235451615 0.3529789127923668 0.3083105416251713 1.0009999993448562]
# Length-10 vectors matching K. The names suggest first (d_ll) and second (d2_ll)
# derivatives of a log-likelihood evaluated at the latent values f — TODO confirm.
# Every entry of d2_ll is negative.
d_ll = [-0.4789038614343079, 0.3526011634298183, -0.4788187573129515, 0.4800358657914971, -0.44917358411354336, 0.34923344477230933, 0.35035293763581776, 0.45598676261552384, -0.4459245159060682, 0.3551514395341081]
d2_ll = [-0.24955495293761715, -0.22827358297775682, -0.24955135495823239, -0.24960143334530488, -0.24741667544813697, -0.2272694458247757, -0.22760575672577055, -0.24806283493493775, -0.24707584202000693, -0.22901889453095886]
f = [-0.08443468127896654, 0.6076249748777804, -0.08477570730728931, 0.0798990150113338, -0.20401031020993757, 0.6224103889912268, 0.6174881948022232, 0.1765098002461307, -0.21715123949401738, 0.5964712105043362]
# Differentiate the scalar sum(a) with respect to d2_ll.
# The do-block argument shadows the global d2_ll; K, f and d_ll are captured
# from the enclosing (global) scope.
# NOTE(review): the steps look like a Newton step of a Laplace approximation
# (cf. Rasmussen & Williams, Algorithm 3.1) — confirm against the AbstractGPs code.
gradient(d2_ll) do d2_ll
# Negated Diagonal of d2_ll. All d2_ll entries above are negative, so W has a
# positive diagonal and sqrt(W) below stays real.
W = -Diagonal(d2_ll)
# Wsqrt is used several times below. Per the error reported in this post, its
# gradient contributions come back as a Diagonal and as a (diag = ...,) NamedTuple,
# and those two cannot be added with `+`.
Wsqrt = sqrt(W)
# `I` is the UniformScaling identity from LinearAlgebra.
B = I + Wsqrt * K * Wsqrt
# Wrap in Symmetric so that cholesky treats B as exactly symmetric.
B_ch = cholesky(Symmetric(B))
b = W * f + d_ll
# Solve with the Cholesky factorisation instead of forming inv(B).
a = b - Wsqrt * (B_ch \ (Wsqrt * K * b))
# `gradient` needs a scalar output, so reduce a to its sum.
sum(a)
end
which fails with `ERROR: MethodError: no method matching +(::Diagonal{Float64, Vector{Float64}}, ::@NamedTuple{diag::Vector{Float64}})`.
I think the culprit is an inconsistency in the types returned by the following two gradients:
# Reduced reproduction. K, d2_ll, f and d_ll must already be defined as in the
# example above. Rebuild the same intermediates, then differentiate two
# expressions with respect to Wsqrt alone.
W = -Diagonal(d2_ll)
Wsqrt = sqrt(W)
B = I + Wsqrt * K * Wsqrt
B_ch = cholesky(Symmetric(B))
b = W * f + d_ll
# B_ch and b are closed over, so they are treated as constants here.
# Reported to return a Diagonal.
gradient(Wsqrt -> sum(Wsqrt * (B_ch \ (Wsqrt * K * b))), Wsqrt)[1]
# Reported to return a NamedTuple with a single `diag` field.
gradient(Wsqrt -> sum(Wsqrt * K * Wsqrt), Wsqrt)[1]
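The first one returns a `Diagonal`, while the second returns a `NamedTuple` with a single field `diag` (i.e. `@NamedTuple{diag::Vector{Float64}}`), and these two types cannot be added together. I am not sure which one is the correct representation; I can imagine either.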
A temporary monkeypatch is to define:
# Temporary workaround for the gradient-type mismatch described above.
# The gradient of a `Diagonal` can come back either as a `Diagonal` or as a
# structural NamedTuple `(diag = ...,)`, and adding the two raises a MethodError.
# These methods add such a pair and always return a `Diagonal`.
#
# Both argument orders are defined. Which form ends up on the left presumably
# depends on the order in which gradient contributions are accumulated.
#
# NOTE(review): this is type piracy. It adds `Base.:+` methods for `Diagonal` and
# `NamedTuple`, neither of which this code owns. Remove it once the upstream
# inconsistency is fixed.
import Base.+
# `NamedTuple{(:diag,),<:Tuple{AbstractVector}}` matches any single-field `diag`
# NamedTuple with a vector payload, not only `Vector{Float64}`.
+(a::Diagonal, b::NamedTuple{(:diag,),<:Tuple{AbstractVector}}) =
    Diagonal(a.diag + b.diag)
+(a::NamedTuple{(:diag,),<:Tuple{AbstractVector}}, b::Diagonal) = b + a
but this does not solve the problem in ChainRules.