Workaround for Zygote.jl issue #317 not working?

Hi,

I’ve run into the problem described in Zygote issues #857, #317 and possibly #495. In essence, when broadcasting over a piecewise function where some branches have no adjoint, Zygote tries to back-propagate nothing values, which causes failures further up the call chain (as far as I can tell; my understanding of Zygote is limited). In my case, it showed up as a stack trace like this:

ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::LLVM.ModuleGlobalSet) at /opt/julia/julia-1.5.3-depot/packages/LLVM/7Q46C/src/core/module.jl:126
  iterate(::LLVM.ModuleGlobalSet, ::Any) at /opt/julia/julia-1.5.3-depot/packages/LLVM/7Q46C/src/core/module.jl:126
  iterate(::Base.AsyncGenerator, ::Base.AsyncGeneratorState) at asyncmap.jl:382
  ...
Stacktrace:
 [1] _zip_iterate_some at ./iterators.jl:352 [inlined]
 [2] _zip_iterate_some at ./iterators.jl:354 [inlined]
 [3] _zip_iterate_all at ./iterators.jl:344 [inlined]
 [4] iterate at ./iterators.jl:334 [inlined]
 [5] iterate at ./generator.jl:44 [inlined]
 [6] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof(∂(λ)),1},Nothing}},Base.var"#3#4"{Zygote.var"#510#514"}}) at ./array.jl:686
 [7] map(::Function, ::Array{typeof(∂(λ)),1}, ::Nothing) at ./abstractarray.jl:2248
 [8] (::Zygote.var"#509#513"{Array{typeof(∂(λ)),1}})(::Nothing) at /opt/julia/julia-1.5.3-depot/packages/Zygote/KpME9/src/lib/array.jl:187
 [9] (::Zygote.var"#540#541"{Zygote.var"#509#513"{Array{typeof(∂(λ)),1}}})(::Nothing) at /opt/julia/julia-1.5.3-depot/packages/Zygote/KpME9/src/lib/array.jl:219
...

which I’ve truncated because it was very long (I’ve mostly included it here to help others facing the same error).
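
For anyone trying to read the trace: frame [7] shows map being called with nothing in place of a collection. The snippet below is only a minimal sketch of that failure mode without Zygote; the names pullbacks and cotangent are hypothetical stand-ins, not Zygote internals. The function named in the error message may differ between Julia versions, so the sketch only checks the error type.

```julia
# Hypothetical stand-ins: a vector of per-element pullbacks and a
# "no gradient" value arriving where an array was expected.
pullbacks = [identity, identity]
cotangent = nothing

# map zips its collections together; `nothing` is not iterable,
# so the call throws before any pullback is applied.
err = try
    map((back, d) -> back(d), pullbacks, cotangent)
    nothing
catch e
    e
end

println(err isa MethodError)
```

So the error is a symptom of a nothing gradient reaching code that expects an array, not of anything specific to the function being differentiated.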

A simple example of the isolated problem is the following:

using LinearAlgebra
using Zygote  # provides gradient
relu(x) = x > 0 ? x : zero(x)  # piecewise: the zero branch is constant
M = fill(1, (1, 1))
gradient(x -> sum(relu.(M * x)), [-1])

which returns (nothing,). This example was given by tkf in #317, who also proposes a workaround:

using Zygote: @adjoint
fixnothing(x) = x
@adjoint fixnothing(x) = fixnothing(x), function(y)
    if y === nothing || y isa AbstractArray{Nothing}
        return (zero(x),)
    else
        return (y,)
    end
end

and tkf gives an example:

julia> gradient(x -> sum(relu.(fixnothing(M * x))), [-1])
([0],)

However, when I try the code, I get

julia> gradient(x -> sum(relu.(fixnothing(M * x))), [-1])
(nothing,)

Can anyone help me understand what’s wrong? I’m using Julia 1.6.0 (downloaded from the website) on Ubuntu 20.04 and Zygote v0.6.9 from the registry.

Thank you in advance

Per issue #317 on GitHub (FluxML/Zygote.jl, "Zygote.jl tries to back-propagate nothings when branching happens inside broadcasting"), this seems to be fixed:

julia> gradient(x -> sum(relu.(M * x)), [-1])
([0.0],)

julia> gradient(x -> sum(relu.(M * x)), [1])
([1.0],)
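
The output above does not say which Zygote version produced it, so it may help to compare against the locally installed one. A minimal sketch using the standard Pkg API, assuming Zygote is installed in the active environment:

```julia
using Pkg

# Print the Zygote version installed in the active environment.
Pkg.status("Zygote")

# Uncomment to update Zygote within the active environment, then rerun the example.
# Pkg.update("Zygote")
```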