Workaround for Zygote.jl issue #317 not working?


I’ve run into the problem described in Zygote issues #857, #317 and possibly #495. In essence, when broadcasting over a piecewise function where some branches have no adjoint, Zygote tries to back-propagate nothing values, which causes problems further up the call chain (I think; my understanding of Zygote is limited). In my case, it manifested as a stack trace like this:

ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::LLVM.ModuleGlobalSet) at /opt/julia/julia-1.5.3-depot/packages/LLVM/7Q46C/src/core/module.jl:126
  iterate(::LLVM.ModuleGlobalSet, ::Any) at /opt/julia/julia-1.5.3-depot/packages/LLVM/7Q46C/src/core/module.jl:126
  iterate(::Base.AsyncGenerator, ::Base.AsyncGeneratorState) at asyncmap.jl:382
 [1] _zip_iterate_some at ./iterators.jl:352 [inlined]
 [2] _zip_iterate_some at ./iterators.jl:354 [inlined]
 [3] _zip_iterate_all at ./iterators.jl:344 [inlined]
 [4] iterate at ./iterators.jl:334 [inlined]
 [5] iterate at ./generator.jl:44 [inlined]
 [6] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof(∂(λ)),1},Nothing}},Base.var"#3#4"{Zygote.var"#510#514"}}) at ./array.jl:686
 [7] map(::Function, ::Array{typeof(∂(λ)),1}, ::Nothing) at ./abstractarray.jl:2248
 [8] (::Zygote.var"#509#513"{Array{typeof(∂(λ)),1}})(::Nothing) at /opt/julia/julia-1.5.3-depot/packages/Zygote/KpME9/src/lib/array.jl:187
 [9] (::Zygote.var"#540#541"{Zygote.var"#509#513"{Array{typeof(∂(λ)),1}}})(::Nothing) at /opt/julia/julia-1.5.3-depot/packages/Zygote/KpME9/src/lib/array.jl:219

which I’ve truncated because it was very long (I’ve mostly included it here to help others facing the same error).

A simple example of the isolated problem is the following:

using LinearAlgebra
using Zygote  # provides gradient
relu(x) = x > 0 ? x : zero(x)
M = fill(1, (1, 1))
gradient(x -> sum(relu.(M * x)), [-1])

which gives the result (nothing,). This example was given by tkf in #317, who proposes a workaround:

using Zygote: @adjoint
fixnothing(x) = x
# Custom adjoint: replace a missing (nothing) gradient with zeros shaped like x.
@adjoint fixnothing(x) = fixnothing(x), function(y)
    if y === nothing || y isa AbstractArray{Nothing}
        return (zero(x),)
    else
        return (y,)
    end
end

and tkf gives an example:

julia> gradient(x -> sum(relu.(fixnothing(M * x))), [-1])

However, when I try the code, I get

julia> gradient(x -> sum(relu.(fixnothing(M * x))), [-1])
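For reference, this is the value I would expect mathematically. The following is only a rough sanity check in plain Julia, without Zygote: it estimates the gradient of the same function by forward differences. The helper name fd_gradient and the step size h are my own choices for illustration and are not part of Zygote or any other package.

```julia
# Plain-Julia sanity check (no Zygote): estimate the gradient by forward differences.
relu(x) = x > 0 ? x : zero(x)
M = fill(1, (1, 1))
f(x) = sum(relu.(M * x))

# Hypothetical helper for illustration only; not a library function.
function fd_gradient(f, x; h = 1e-6)
    x0 = float.(x)              # floating-point copy of the input
    g = zeros(length(x0))
    for i in eachindex(x0)
        xh = copy(x0)
        xh[i] += h              # perturb only the i-th component
        g[i] = (f(xh) - f(x0)) / h
    end
    return g
end

fd_gradient(f, [-1])  # evaluates to [0.0]
```

So at [-1] the function is locally constant and a zero gradient, rather than nothing, is what I am hoping to get back.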

Can anyone help me understand what’s wrong? I’m using Julia 1.6.0 (downloaded from the website) on Ubuntu 20.04 and Zygote v0.6.9 from the registry.

Thank you in advance