Cannot take the gradient of a function that involves `eachrow`/`eachcol` with Zygote?

The following code

using Zygote
gradient(x -> sum(map(sum, eachcol(x))), rand(3, 4))

gives the error message

ERROR: Need an adjoint for constructor Base.Generator{Base.OneTo{Int64}, Base.var"#193#194"{Matrix{Float64}}}. Gradient is of type Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}

However, if I wrap the result of `eachcol` in `collect`, I get the correct gradient:

gradient(x -> sum(map(sum, collect(eachcol(x)))), rand(3, 4))
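Another way around the problem is to avoid building the column iterator at all. The sketch below is a suggested alternative, not something from the thread: it assumes that per-column sums can be expressed with `sum(x; dims=1)`, which Zygote differentiates without needing a rule for the `Base.Generator` returned by `eachcol`. Since the function is just the sum of all entries of `x`, its gradient should be a matrix of ones.

```julia
using Zygote

x = rand(3, 4)

# Column sums without eachcol: sum(x; dims=1) returns a 1×4 matrix of
# column sums, and the outer sum adds them up, so every entry of x
# contributes with weight 1.
g, = gradient(x -> sum(sum(x; dims=1)), x)

# The gradient of "sum of all entries" is a 3×4 matrix of ones.
@show g
```

This only helps when the per-column operation has a vectorized equivalent; for an arbitrary function mapped over columns, the `collect` workaround above remains the simpler option.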

Is this a limitation of Zygote, or did I use it wrong? Thanks!


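Your code is fine.
That error message means that Zygote is missing a rule (an adjoint) for that constructor, i.e. the `Base.Generator` that `eachcol` returns here, so this is a limitation of Zygote rather than a mistake on your part.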
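If you cannot wait for a fix in Zygote, one option is to supply a rule yourself at a level Zygote never has to look inside. The sketch below is an illustration under stated assumptions, not what the linked PR does: it wraps the column-wise sum in a hypothetical helper `colsums` and defines a `ChainRulesCore.rrule` for it, which Zygote picks up automatically. It assumes `ChainRulesCore` is installed in the active environment.

```julia
using Zygote, ChainRulesCore

# Hypothetical helper: per-column sums computed via eachcol.
colsums(x::AbstractMatrix) = map(sum, eachcol(x))

# Custom reverse rule, so Zygote never differentiates through eachcol itself.
function ChainRulesCore.rrule(::typeof(colsums), x::AbstractMatrix)
    y = colsums(x)
    function colsums_pullback(ȳ)
        # d(colsums(x)[j]) / d x[i, j] = 1, so every entry in column j
        # receives the incoming cotangent ȳ[j].
        x̄ = ones(eltype(x), size(x)) .* reshape(collect(unthunk(ȳ)), 1, :)
        return NoTangent(), x̄
    end
    return y, colsums_pullback
end

x = rand(3, 4)
g, = gradient(x -> sum(colsums(x)), x)
@show g
```

Writing the rule for the whole helper, rather than for the generator constructor, keeps the pullback simple and independent of how `eachcol` is implemented in a given Julia version.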


Thanks! I see someone has made a PR related to this issue (https://github.com/FluxML/Zygote.jl/pull/896). If I add the adjoints from this PR, my original code works…


Thanks for this tip, it helped me work around this issue.