I’m trying to do something like this (I hope the snippet is enough, although it’s incomplete):
prob_primitives = [(θ[i+1], p) for (i, p) in enumerate(primitives)]
where primitives is a Vector. However, using this code inside a Zygote.gradient call results in the error MethodError: no method matching iterate(::Nothing).
I’ve tried this code as an alternative:
prob_primitives = collect(zip(θ[2:end], primitives))
but collect constructs the array using mutation, which is also unsupported.
Answering this myself: it’s possible by creating a Zygote.Buffer, filling it through mutation, and calling copy on it to get an ordinary array back. I don’t know if it’s the best solution, but it works.
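For anyone landing here later, this is a minimal sketch of the Buffer approach. It assumes θ and primitives are plain numeric vectors. The function name weighted_sum and the sample values are made up for illustration, and the tuples from the original snippet are replaced by a product so that the function returns a scalar.

```julia
using Zygote

# Hypothetical stand-in for the original comprehension: pair θ[i+1] with
# primitives[i], here multiplied together so the result can be summed.
function weighted_sum(θ, primitives)
    n = length(primitives)
    # Buffer(θ, n) allocates a write-once buffer with θ's element type.
    buf = Zygote.Buffer(θ, n)
    for i in 1:n
        buf[i] = θ[i + 1] * primitives[i]
    end
    # copy turns the buffer back into a regular array that Zygote can track.
    return sum(copy(buf))
end

θ = [0.5, 1.0, 2.0, 3.0]          # illustrative values only
primitives = [10.0, 20.0, 30.0]   # illustrative values only

g = gradient(t -> weighted_sum(t, primitives), θ)[1]
```

Since θ[1] is never read, its gradient entry should be zero, and the remaining entries should match primitives.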
You would get better help here in the forum by posting a self-contained minimal reproducible example that produces your error. See the post “Please read: make it easier to help you” for more details on getting help.
While the Buffer “solution” does work, it complicates things and doesn’t make for very clean code, so I’d still like comprehensions in general to work. In most cases, though, I still can’t get them to.
Here’s an example where I can’t make it work, although it also uses a dictionary which might add another issue to the mix:
using Zygote
w = randn(5)
grads = gradient(Params(w)) do
    dict = Dict{Int, Float64}(i => v for (i,v) in enumerate(w))
    m = maximum(i*v for (i,v) in dict)
    k = randn(5)
    sum(k .- m)
end
This results in MethodError: no method matching getindex(::Dict{Any,Any}). Besides, you only get this far with the generator syntax, since the array comprehension syntax allocates an array and copies data into it, which is unsupported mutation. I’ve tried a bunch of variations on the maximum() line to no avail.
Indeed, learning Zygote has involved a bit of stepping on landmines of unsupported operations, but once you learn those bits, it’s quite amazing. Here, your problem is differentiating through indexing into a Dict. Here’s a minimal version that works (no dict):
using Zygote
w = randn(5)
grads = gradient(Params([w])) do
    m = maximum(i*v for (i,v) in enumerate(w))
    k = randn(5)
    sum(k .- m)
end
Is it doable with the dictionary, though? My example is a condensed version of some code I actually have, and the data structure I need to iterate over is indeed a dictionary. Unlike the toy case, I can’t easily replace the dict with something else.
I haven’t been able to get any function that constructs a Dict and then accesses it to work with Zygote. But this hasn’t really come up in any applications that I’ve built. Usually, when I need to create key-value mappings and work with them inside of differentiable functions, I use NamedTuples and it works great.
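A rough sketch of what the NamedTuple approach can look like. The field names a and b and the function named_loss are invented for this example, and it assumes the set of keys is small and known when writing the code.

```julia
using Zygote

# Hypothetical loss that builds a key-value mapping as a NamedTuple
# inside the differentiated function and then reads from it by name.
function named_loss(w)
    nt = (a = w[1], b = w[2])
    return nt.a^2 + 3 * nt.b
end

w = [2.0, 1.0]                 # illustrative values only
g = gradient(named_loss, w)[1]
```

The derivative is 2 * w[1] with respect to the first entry and 3 with respect to the second. The obvious limitation compared to a Dict is that the keys have to be symbols fixed in the source.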
This can be made to work if instead of iterating over the dictionary one iterates over the keys explicitly:
using Zygote
w = randn(5)
grads = Zygote.gradient(Zygote.Params(w)) do
    dict = Dict{Int, Float64}(i => v for (i,v) in enumerate(w))
    # m = maximum(i*v for (i,v) in dict) # doesn't work
    m = maximum(i -> i * dict[i], eachindex(w)) # works
    k = randn(5)
    sum(k .- m)
end
This works because getindex(d::Dict, ...) has adjoints defined in Zygote, whereas I’m not sure the iteration protocol does. In general, I have found that iterating over set-like unordered containers such as Dicts tends not to work unless you reformulate it as iteration over array-like containers (e.g. iterating over eachindex(w) here). I wouldn’t recommend using Dicts this way in performance-critical areas, but if you are just returning e.g. Dicts of loss function values, that should be fine.
Unfortunately, eachindex or keys on a dictionary returns a Base.KeySet which is also a dictionary-like container (e.g. map is undefined). And collect is seemingly unsupported since internally it mutates an array.
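One possible way around this, sketched under the assumption that the dictionary already exists before the gradient call: do the collect outside the differentiated function, where mutation is not a problem, and hand Zygote two parallel arrays instead. The helper name weighted_max and the sample dictionary are made up for illustration, and this does not cover the case where the Dict itself has to be built inside the gradient call.

```julia
using Zygote

# Illustrative dictionary; in real code this would be the existing data.
dict = Dict(1 => 0.5, 2 => -1.0, 3 => 2.0)

# Unpack outside the gradient call, so collect's internal mutation is
# never seen by Zygote. Sorting just makes the key order deterministic.
ks = sort(collect(keys(dict)))
vs = [dict[k] for k in ks]

# Hypothetical array-only counterpart of maximum(i*v for (i,v) in dict).
weighted_max(keys_vec, vals) = maximum(keys_vec .* vals)

g = gradient(v -> weighted_max(ks, v), vs)[1]
```

Only the entry that attains the maximum gets a nonzero gradient, equal to its key. The entries of g line up with ks, so they can be mapped back to dictionary keys afterwards if needed.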