List comprehension in Zygote

I’m trying to do something like this (I hope the snippet is enough, although it’s incomplete):
prob_primitives = [(θ[i+1], p) for (i, p) in enumerate(primitives)]
where primitives is a Vector. However, using this code inside a Zygote.gradient call fails with MethodError: no method matching iterate(::Nothing).

I’ve tried this code as an alternative,
prob_primitives = collect(zip(θ[2:end], primitives))
but collect constructs the array using mutation, which is also unsupported.

Answering this myself: it’s possible by creating a Zygote.Buffer and filling it through mutation. I don’t know if it’s the best solution, but it works.
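For anyone landing here later, a minimal sketch of the Buffer approach might look like the following. The function name weighted_sum and the choice to multiply θ[i+1] by each primitive (rather than storing tuples, as in the original snippet) are assumptions made for illustration; the original code is incomplete, so this is not a drop-in replacement.

```julia
using Zygote

# Hypothetical stand-in for the original code: pairs θ[i+1] with primitives[i],
# here by multiplying them so that the result is a plain numeric vector.
function weighted_sum(θ, primitives)
    # Buffer(θ, n) wraps an array like similar(θ, n) that Zygote allows writes to.
    buf = Zygote.Buffer(θ, length(primitives))
    for i in eachindex(primitives)
        buf[i] = θ[i + 1] * primitives[i]
    end
    # copy(buf) turns the Buffer back into an ordinary array.
    return sum(copy(buf))
end

θ = [1.0, 2.0, 3.0]
primitives = [10.0, 20.0]
g = Zygote.gradient(θ -> weighted_sum(θ, primitives), θ)[1]
```

Here θ[1] is never used, so its gradient entry is zero, and the other two entries are the corresponding primitives.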

Zygote works fine with array comprehensions. I use it frequently. Here’s a simple example:

using Zygote

x = rand(5)
y = rand(5)

f(y) = sum([x[i] * y[i] for i in eachindex(y)])

Then for my random x,y, I have:

julia> Zygote.gradient(f, y)
([0.20735224892728188, 0.023609460904363777, 0.8209142299297378, 0.5328550686359217, 0.11713217978556711],)

You would get better help here in the forum by posting a self-contained minimal reproducible example that produces your error. See PSA: make it easier to help you for more details on getting help.


While the Buffer “solution” does work, it complicates things and doesn’t make for very clean code, so I’d still like comprehensions to work in general. In most cases, though, I still can’t get them to work.
Here’s an example where I can’t make it work, although it also uses a dictionary which might add another issue to the mix:


using Zygote

w = randn(5)
grads = gradient(Params(w)) do
    dict = Dict{Int, Float64}(i => v for (i,v) in enumerate(w))
    m = maximum(i*v for (i,v) in dict)
    k = randn(5)
    sum(k .- m)
end

This results in MethodError: no method matching getindex(::Dict{Any,Any}). Besides, you only get this far with the generator syntax, since the array comprehension syntax allocates an array and copies data into it, which is unsupported mutation. I’ve tried a bunch of variations on the maximum() line, to no avail.

Indeed, learning Zygote has involved a bit of stepping on landmines of unsupported operations, but once you learn those bits, it’s quite amazing. Here, your problem is differentiating through indexing into a Dict. Here’s a minimal version that works (no Dict):

using Zygote

w = randn(5)

grads = gradient(Params([w])) do
    m = maximum(i*v for (i,v) in enumerate(w))
    k = randn(5)
    sum(k .- m)
end

julia> grads[w]
5-element Array{Float64,1}:
   0.0
   0.0
   0.0
   0.0
 -50.0

Note a couple of important changes from your version:

  • No Dict. In fact, there is no intermediate data structure to hold the pairs at all.
  • Params(w) should be Params([w]) or Params((w,)) because the argument to Params should be an iterable of objects you want to differentiate w.r.t.
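To illustrate the second point, here is a small sketch of the implicit-parameter style with the array wrapped in a collection. The loss sum(abs2, w) and the fixed values of w are chosen only for illustration.

```julia
using Zygote

w = [1.0, 2.0, 3.0]

# Params takes an iterable of the objects to differentiate with respect to,
# so the array itself is wrapped in a vector.
ps = Zygote.Params([w])
grads = Zygote.gradient(() -> sum(abs2, w), ps)

# The result is indexed by the parameter object itself.
gw = grads[w]
```

Since the loss is the sum of squares, gw should equal 2 .* w.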

Is it doable with the dictionary, though? My example is a condensed version of some code I actually have, and the data structure I need to iterate over is indeed a dictionary. Unlike the toy case, I can’t easily replace the dict with something else.

I haven’t been able to get any function that constructs a Dict and then accesses it to work with Zygote. But this hasn’t really come up in any applications that I’ve built. Usually, when I need to create key-value mappings and work with them inside of differentiable functions, I use NamedTuples and it works great.
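As a rough sketch of what I mean by the NamedTuple approach (the field names a and b and the function loss are made up for this example):

```julia
using Zygote

# A NamedTuple plays the role of the key-value mapping.
params = (a = 2.0, b = 1.0)

# Hypothetical loss that reads entries by field name.
loss(nt) = nt.a^2 + 3 * nt.b

# The gradient comes back as a NamedTuple with the same field names.
g = Zygote.gradient(loss, params)[1]
```

The gradient can then be read as g.a and g.b, mirroring the structure of the input.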

This can be made to work if instead of iterating over the dictionary one iterates over the keys explicitly:

using Zygote

w = randn(5)
grads = Zygote.gradient(Zygote.Params([w])) do
    dict = Dict{Int, Float64}(i => v for (i,v) in enumerate(w))
    # m = maximum(i*v for (i,v) in dict) # doesn't work
    m = maximum(i -> i * dict[i], eachindex(w)) # works
    k = randn(5)
    sum(k .- m)
end

This works because getindex(d::Dict, ...) has adjoints defined in Zygote, whereas I’m not sure the iteration protocol does. In general, I have found that iterating over set-like unordered containers such as Dicts doesn’t tend to work unless you reformulate it as iteration over array-like containers (e.g. iterating over eachindex(w) here). I wouldn’t recommend using Dicts this way in performance-critical areas, but if you are just returning, say, a Dict of loss function values, it should be fine.


Unfortunately, eachindex or keys on a dictionary returns a Base.KeySet, which is also a dictionary-like container (for example, map is not defined on it). And collect is seemingly unsupported, since internally it mutates an array.
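The plain-Julia sketch below shows the behaviour described here, without involving Zygote. The last line is one possible workaround, under the assumption that the keys are known before the gradient call, so that collect can run outside the differentiated function; whether that fits a given use case is not something this thread settles.

```julia
d = Dict(1 => 0.5, 2 => -1.0)

ks = keys(d)
is_keyset = ks isa Base.KeySet   # keys on a Dict does not return a Vector

# map refuses to operate on set-like containers such as KeySet.
map_failed = try
    map(identity, ks)
    false
catch
    true
end

# Assumed workaround: build an ordinary Vector of keys ahead of time,
# outside any function that Zygote has to differentiate.
key_vec = sort!(collect(ks))
```

Inside the differentiated function one would then iterate over key_vec and index into the Dict, in the same way as the eachindex(w) version earlier in the thread.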