List comprehension in Zygote

I’m trying to do something like this (I hope the snippet is enough, although it’s incomplete):
prob_primitives = [(θ[i+1], p) for (i, p) in enumerate(primitives)]
where primitives is a Vector. However, using this code inside a Zygote.gradient call results in the error MethodError: no method matching iterate(::Nothing).

I’ve tried this code as an alternative:
prob_primitives = collect(zip(θ[2:end], primitives))
but collect builds the array using mutation, which is also unsupported.

Answering this myself: it’s possible by creating a Zygote.Buffer and filling it through mutation. I don’t know if it’s the best solution, but it works.
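For reference, a minimal sketch of the Buffer approach might look like the following. It assumes, purely for illustration, that each (θ[i+1], p) pair ends up being consumed as the product θ[i+1] * p; the values of primitives and θ are made up, and the real code may need a buffer with a different element type.

```julia
using Zygote

# Hypothetical stand-ins for the poster's data (not from the original code)
primitives = [1.0, 2.0, 3.0]

function loss(θ)
    # Zygote.Buffer is an array-like object that supports setindex! inside a
    # differentiated function; this one is similar to θ, with length(primitives) slots
    buf = Zygote.Buffer(θ, length(primitives))
    for i in eachindex(primitives)
        # Illustrative use of the (θ[i+1], p) pair: store their product
        buf[i] = θ[i+1] * primitives[i]
    end
    # copy turns the Buffer back into an ordinary array
    return sum(copy(buf))
end

θ = [0.5, 1.0, 2.0, 3.0]
g, = Zygote.gradient(loss, θ)
```

Since this loss is Σ θ[i+1] * primitives[i], the gradient with respect to θ should be [0.0, 1.0, 2.0, 3.0]: θ[1] is never read, and every later entry picks up its matching primitive.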

Zygote works fine with array comprehensions. I use it frequently. Here’s a simple example:

using Zygote

x = rand(5)
y = rand(5)

f(y) = sum([x[i] * y[i] for i in eachindex(y)])

Then, for my random x and y, I get:

julia> Zygote.gradient(f, y)
([0.20735224892728188, 0.023609460904363777, 0.8209142299297378, 0.5328550686359217, 0.11713217978556711],)
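As a quick sanity check on that output: since f(y) is just Σ x[i] * y[i], its gradient with respect to y should be x itself. The sketch below reuses the same x, y and f and compares the two.

```julia
using Zygote

x = rand(5)
y = rand(5)

# Same comprehension-based function as above
f(y) = sum([x[i] * y[i] for i in eachindex(y)])

g, = Zygote.gradient(f, y)

# ∂/∂y[i] of Σ x[j] * y[j] is x[i], so the gradient should match x
println(g ≈ x)
```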

You would get better help here in the forum by posting a self contained minimal reproducible example that produces your error. See PSA: make it easier to help you for more details on getting help.


While the Buffer “solution” does work, it complicates things and doesn’t make for very clean code, so I’d still like comprehensions to work in general, but in most cases I still can’t get them to.
Here’s an example where I can’t make it work, although it also uses a dictionary, which might add another issue to the mix:


using Zygote

w = randn(5)
grads = gradient(Params(w)) do
    dict = Dict{Int, Float64}(i => v for (i,v) in enumerate(w))
    m = maximum(i*v for (i,v) in dict)
    k = randn(5)
    sum(k .- m)
end

This results in MethodError: no method matching getindex(::Dict{Any,Any}). Besides, you only get this far with the generator syntax, since the array comprehension syntax allocates an array and copies data into it, which is unsupported mutation. I’ve tried a bunch of variations on the maximum() line, to no avail.

Indeed, learning Zygote has involved stepping on a few landmines of unsupported operations, but once you learn where they are, it’s quite amazing. Here, your problem is differentiating through indexing into a Dict. Here’s a minimal version that works (no Dict):

using Zygote

w = randn(5)

grads = gradient(Params([w])) do
    m = maximum(i*v for (i,v) in enumerate(w))
    k = randn(5)
    sum(k .- m)
end
julia> grads[w]
5-element Array{Float64,1}:
   0.0
   0.0
   0.0
   0.0
 -50.0

Note a couple of important changes from your version:

  • No Dict. In fact, there is no intermediate data structure to hold the pairs at all.
  • Params(w) should be Params([w]) or Params((w,)), because the argument to Params should be an iterable of the arrays you want to differentiate with respect to.
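To illustrate the second point, here is a sketch comparing the implicit Params([w]) form with Zygote’s explicit form, where w is passed in as an argument. For this sketch only, the generator is replaced by the broadcast (1:length(w)) .* w and k is drawn once outside the closure; both are simplifications, not part of the answer above.

```julia
using Zygote

w = randn(5)
k = randn(5)   # drawn once, outside the closures, so both calls below see the same k

# Implicit style: Params takes a collection of the arrays to track
grads = gradient(Params([w])) do
    m = maximum((1:length(w)) .* w)   # broadcast used in place of the generator (sketch only)
    sum(k .- m)
end
g_implicit = grads[w]

# Explicit style: differentiate with respect to the argument w instead
g_explicit, = gradient(w) do w
    m = maximum((1:length(w)) .* w)
    sum(k .- m)
end
```

Both forms should give the same gradient here, with a single nonzero entry at the position that attains the maximum.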

Is it doable with the dictionary, though? My example is a condensed version of some code I actually have, and the data structure I need to iterate over is indeed a dictionary. Unlike the toy case, I can’t easily replace the dict with something else.

I haven’t been able to get any function that constructs a Dict and then accesses it to work with Zygote. But this hasn’t really come up in any applications that I’ve built. Usually, when I need to create key-value mappings and work with them inside of differentiable functions, I use NamedTuples and it works great.
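A minimal sketch of the NamedTuple approach, with hypothetical field names a and b: Zygote can take the gradient with respect to a NamedTuple argument, and the result has the same keys, so each partial derivative can be looked up by name.

```julia
using Zygote

# Hypothetical key-value parameters stored as a NamedTuple instead of a Dict
nt = (a = 2.0, b = 3.0)

# Entries are accessed by name, much as one would index a Dict by key
loss(p) = p.a^2 + 3 * p.b

g, = Zygote.gradient(loss, nt)
# g.a should be 2 * 2.0 and g.b should be 3.0
```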