My gradient IdDict keys are GlobalRefs instead of arrays

My gradients sometimes appear with keys of type GlobalRef. I do not understand why and I do not understand how to change that.

Example code:

using Flux

function h(x; bs=bs)
    for b in bs[1:end-1]
        x = x+b
    end
    x = x+bs[end]
    only(x)
end

x = [0.0]  # x is not defined in the original post; a 1-element vector is assumed
bs = [ [0.0], [1.0], [2.0] ]
grads = gradient(() -> h(x, bs=bs), params(bs))
@show grads[bs[1]]
@show grads[bs[2]]
@show grads[bs[3]]
for (k, v) in grads.grads
    if k isa GlobalRef
        @show k
        @show v
    end
end

The output shows that the grads.grads IdDict contains the expected gradients for bs[1] and bs[2], keyed by the “arrays” themselves (I am not sure what the terminology is), but that it contains nothing for bs[3]. The same IdDict does, however, contain one key which is a GlobalRef to :(Main.bs); its value contains nothing for the first two elements, but the expected gradient for the last element.

  1. What should this function look like to give me the expected gradient as grads[bs[3]]?
  2. What determines if the gradient appears with a GlobalRef or array as its key in the IdDict, and what do I need to understand to avoid this problem in the future?

This feels like a bug in Zygote.
I have never seen this occur before.

Maybe open an issue on the Zygote.jl repo.

GlobalRefs have been a feature of implicit params for some time now. AIUI Zygote needs them and there was a plan to expose them for updating global variables, but that plan was canned. Usually you don’t see them because Zygote.Grads ignores them when iterating, but everything is retained in the underlying IdDict.
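
To see the difference between the parameters and the raw IdDict for yourself, something like the following rough sketch might help. It assumes a Flux/Zygote version that still supports implicit params, and it assumes x = [0.0] since the original post does not define x. split_keys is a made-up helper name, not part of Zygote. I have left the output out on purpose, since which GlobalRef keys show up can depend on which non-const globals the closure touches.

```julia
using Flux  # provides `gradient` and `params`, as in the original post

# Hypothetical helper (not a Zygote function): split the keys of the raw
# IdDict behind a Zygote.Grads object into GlobalRef keys and all other keys.
function split_keys(gs)
    refs   = [k for k in keys(gs.grads) if k isa GlobalRef]
    others = [k for k in keys(gs.grads) if !(k isa GlobalRef)]
    return refs, others
end

x  = [0.0]                   # assumed input; not given in the original post
bs = [[0.0], [1.0], [2.0]]

function h(x; bs=bs)
    for b in bs[1:end-1]
        x = x + b
    end
    x = x + bs[end]
    only(x)
end

gs = gradient(() -> h(x, bs=bs), params(bs))

refs, others = split_keys(gs)
@show refs                   # the GlobalRef keys kept in the underlying IdDict

# Looking things up through the tracked parameters never touches GlobalRef keys.
for p in gs.params
    @show gs[p]
end
```

The loop over gs.params only visits the arrays that were passed to params, which is why the GlobalRef entries normally go unnoticed.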

Back to the problem at hand, since only one parameter is being used here, it makes sense to use explicit instead of implicit gradients. This will capture each element of bs:

julia> grads = only(gradient(bs -> h(x, bs=bs), bs))
3-element Vector{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}:
 [1.0]
 [1.0]
 [1.0]

julia> map(collect, grads) # just for the nicer show
3-element Vector{Vector{Float64}}:
 [1.0]
 [1.0]
 [1.0]
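
For completeness, here is a self-contained sketch of the explicit version as a script. Again x = [0.0] is an assumption, and making bs a required keyword (instead of defaulting to the global) is my own variation on the original h, not something the question requires. With explicit gradients the result has the same layout as bs, so the gradient for the third element is found by position rather than by an IdDict lookup.

```julia
using Flux  # exports `gradient`

# Same computation as in the question, but `bs` is a required keyword here,
# so h itself never falls back to a global variable.
function h(x; bs)
    for b in bs[1:end-1]
        x = x + b
    end
    x = x + bs[end]
    only(x)
end

x  = [0.0]                   # assumed input; not given in the original post
bs = [[0.0], [1.0], [2.0]]

# Explicit gradient with respect to the argument `b`; `only` unwraps the
# 1-tuple that `gradient` returns for a single argument.
g = only(gradient(b -> h(x; bs=b), bs))

# The gradient for bs[3] sits at position 3; `collect` turns it into a plain Vector.
g3 = collect(g[3])
@show g3
```

If the entries are needed as ordinary vectors everywhere, map(collect, g) as shown above does the same for all of them at once.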