[Zygote] gradient throws error when mean function is used

Hello,

I am quite a Julia newbie, currently migrating from the Python/JAX ecosystem to Julia, and I have encountered a frustrating error while computing the gradient of a fairly simple function.

The problem is the function `mean`, which causes the error. If I change `mean` to, for example, `norm`, the gradient is computed without any problem.

Could someone explain why `mean` generates this error and how I can solve it?

Minimum working example

```julia
using Statistics
using LinearAlgebra
using Zygote

D = 2
N = [rand(D) for i ∈ 1:5]   # vector of 5 length-D vectors
S = [rand(D) for i ∈ 1:5]

function cost(N, S)
    sum(mean(N .- S))
end

cost(N, S)                  # evaluates fine
∇ = gradient(cost, N, S)    # throws the MethodError below
```
```
MethodError: no method matching zero(::Type{Vector{Float64}})

Closest candidates are:
zero(!Matched::Union{Type{P}, P}) where P<:Dates.Period at /build/julia/src/julia-1.6.1/usr/share/julia/stdlib/v1.6/Dates/src/periods.jl:53
zero(!Matched::CartesianIndex{N}) where N at multidimensional.jl:106
zero(!Matched::LinearAlgebra.UniformScaling{T}) where T at /build/julia/src/julia-1.6.1/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/uniformscaling.jl:136
...

zero(::Vector{Vector{Float64}})@abstractarray.jl:1085
_backmean(::Vector{Vector{Float64}}, ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, ::Colon)@array.jl:327
(::Zygote.var"#656#657"{Colon, Vector{Vector{Float64}}})(::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})@array.jl:325
(::Zygote.var"#2769#back#658"{Zygote.var"#656#657"{Colon, Vector{Vector{Float64}}}})(::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})@adjoint.jl:59
Pullback@Local: 7[inlined]
(::typeof(∂(cost)))(::Float64)@interface2.jl:0
(::Zygote.var"#41#42"{typeof(∂(cost))})(::Float64)@interface.jl:41
gradient(::Function, ::Vector{Vector{Float64}}, ::Vararg{Vector{Vector{Float64}}, N} where N)@interface.jl:59
top-level scope@Local: 10
```

I think the problem is that `N` and `S` are vectors of vectors instead of matrices. Try `N = rand(D, 5); S = rand(D, 5)` and `cost(N, S) = sum(mean(N .- S, dims=2))` instead. (And then notice that the gradient is independent of the input data, since the cost is just a (weighted) sum of it.)
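A minimal sketch of the matrix version suggested above, assuming the same `D = 2` and 5 samples as in the question, with each sample stored as one column. Because the cost is linear in `N` and `S`, every entry of the gradient with respect to `N` should be `1/5` and every entry with respect to `S` should be `-1/5`, whatever the random data is.

```julia
using Statistics
using Zygote

D = 2
N = rand(D, 5)   # D×5 matrix: column j is the j-th sample
S = rand(D, 5)

# mean over the columns (dims=2) gives a D×1 matrix, the analogue of
# taking the mean of a vector of 5 length-D vectors
cost(N, S) = sum(mean(N .- S, dims=2))

∇N, ∇S = gradient(cost, N, S)
# ∇N should be a D×5 matrix filled with 1/5, ∇S a D×5 matrix filled with -1/5
```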

As an explanation: the problem is that Julia/Zygote does not know what the neutral (zero) element for addition of a vector of vectors is (which it needs to compute the gradient of `mean`), because the type `Vector{Float64}` does not say how long each inner vector is. It does know that the neutral element for addition of matrices is the zero matrix.
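To illustrate the point about zero elements, here is a small sketch using only Base Julia (no Zygote). On Julia 1.6, the version in the stack trace, `zero` on a vector of vectors fails because it ends up calling `zero(Vector{Float64})`, and that type carries no length; newer Julia releases may handle this case differently. Building the zero from the values rather than from the type, for example by broadcasting `zero` over the inner vectors, works either way.

```julia
A = rand(2, 5)
Z_matrix = zero(A)   # a 2×5 matrix of zeros; the size is taken from A itself

V = [rand(2) for _ in 1:5]
# On Julia 1.6, zero(V) throws
#   MethodError: no method matching zero(::Type{Vector{Float64}})
# because a zero built from the type Vector{Float64} alone cannot know
# how long the inner vectors should be.

# Broadcasting zero over the elements uses each inner vector's actual length:
Z = zero.(V)         # 5 vectors, each equal to [0.0, 0.0]
```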


It does seem like Zygote should be able to do this; it might be worth filing a bug. Until then, you can define a rule for the specific operation on which the existing generic rule fails, which is `mean(::Vector{<:Vector})`. Defining something like the following makes your existing code work as-is:

```julia
Zygote.@adjoint function mean(A::Vector{Vector{T}}) where {T}
    # each of the length(A) inner vectors receives an equal share of the cotangent Δ
    mean(A), Δ -> ([similar(a) .= Δ ./ length(A) for a in A],)
end
```

(FWIW, performance-wise I'd guess the 2D matrix solution is better, but perhaps that's not an issue here. Also, if it matters, it's probably possible to write a much more performant version of what I wrote above.)
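As a usage example, here is the question's script with the rule from this answer added, as one complete sketch. It assumes `Zygote.@adjoint` is available, as it is in the Zygote version used in the question. Since the cost is linear, each `∇N[i]` should come out as a length-`D` vector of `1/5` and each `∇S[i]` as a vector of `-1/5`.

```julia
using Statistics
using Zygote

# Rule from the answer: split the cotangent Δ of mean(A) evenly
# across the length(A) inner vectors.
Zygote.@adjoint function mean(A::Vector{Vector{T}}) where {T}
    mean(A), Δ -> ([similar(a) .= Δ ./ length(A) for a in A],)
end

D = 2
N = [rand(D) for _ in 1:5]   # vector of 5 length-D vectors, as in the question
S = [rand(D) for _ in 1:5]

cost(N, S) = sum(mean(N .- S))

∇N, ∇S = gradient(cost, N, S)
# ∇N[i] should be fill(1/5, D) and ∇S[i] should be fill(-1/5, D) for every i
```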
