Hi,
I’ve run into the problem described in Zygote issues #857, #317 and possibly #495. In essence, when broadcasting a piecewise function in which some branches have no adjoint, Zygote tries to back-propagate nothing values, which causes errors further up the call chain (as far as I can tell; my understanding of Zygote is limited). In my case, it manifested as a stack trace like this:
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
iterate(::LLVM.ModuleGlobalSet) at /opt/julia/julia-1.5.3-depot/packages/LLVM/7Q46C/src/core/module.jl:126
iterate(::LLVM.ModuleGlobalSet, ::Any) at /opt/julia/julia-1.5.3-depot/packages/LLVM/7Q46C/src/core/module.jl:126
iterate(::Base.AsyncGenerator, ::Base.AsyncGeneratorState) at asyncmap.jl:382
...
Stacktrace:
[1] _zip_iterate_some at ./iterators.jl:352 [inlined]
[2] _zip_iterate_some at ./iterators.jl:354 [inlined]
[3] _zip_iterate_all at ./iterators.jl:344 [inlined]
[4] iterate at ./iterators.jl:334 [inlined]
[5] iterate at ./generator.jl:44 [inlined]
[6] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof(∂(λ)),1},Nothing}},Base.var"#3#4"{Zygote.var"#510#514"}}) at ./array.jl:686
[7] map(::Function, ::Array{typeof(∂(λ)),1}, ::Nothing) at ./abstractarray.jl:2248
[8] (::Zygote.var"#509#513"{Array{typeof(∂(λ)),1}})(::Nothing) at /opt/julia/julia-1.5.3-depot/packages/Zygote/KpME9/src/lib/array.jl:187
[9] (::Zygote.var"#540#541"{Zygote.var"#509#513"{Array{typeof(∂(λ)),1}}})(::Nothing) at /opt/julia/julia-1.5.3-depot/packages/Zygote/KpME9/src/lib/array.jl:219
...
which I’ve truncated because it was very long (I’ve mostly included it here to help others facing the same error).
A simple example of the isolated problem is the following:
using LinearAlgebra
using Zygote
relu(x) = x > 0 ? x : zero(x)
M = fill(1, (1, 1))
gradient(x -> sum(relu.(M * x)), [-1])
which gives the result (nothing,). This example was given by tkf in #317, who proposes a workaround:
using Zygote: @adjoint
fixnothing(x) = x
@adjoint fixnothing(x) = fixnothing(x), function(y)
    if y === nothing || y isa AbstractArray{Nothing}
        return (zero(x),)
    else
        return (y,)
    end
end
and tkf gives an example:
julia> gradient(x -> sum(relu.(fixnothing(M * x))), [-1])
([0],)
However, when I run the same code, I get:
julia> gradient(x -> sum(relu.(fixnothing(M * x))), [-1])
(nothing,)
Can anyone help me understand what’s wrong? I’m using Julia 1.6.0 (downloaded from the website) on Ubuntu 20.04 and Zygote v0.6.9 from the registry.
Thank you in advance