List comprehension in Zygote

This can be made to work if, instead of iterating over the dictionary itself, you iterate over its keys explicitly and look up each value with `dict[i]`:

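```julia
using Zygote

w = randn(5)
# Params takes a collection of arrays, so w is wrapped as [w]
grads = Zygote.gradient(Zygote.Params([w])) do
    dict = Dict{Int, Float64}(i => v for (i, v) in enumerate(w))
    # m = maximum(i * v for (i, v) in dict)      # doesn't work: iterates over the Dict's pairs
    m = maximum(i -> i * dict[i], eachindex(w))  # works: iterates over the keys, uses getindex
    k = randn(5)
    sum(k .- m)
end
```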

This works because `getindex(d::Dict, ...)` has an adjoint defined in Zygote, whereas I’m not sure the iteration protocol for `Dict` does. In general, I have found that iterating over set-like, unordered containers such as `Dict`s doesn’t tend to work unless you reformulate it as iteration over an array-like container (e.g. iterating over `eachindex(w)` here). I wouldn’t recommend using `Dict`s this way in performance-critical code, but if you are just returning, e.g., a `Dict` of loss function values, that should be fine.
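Taking the same idea one step further, you can often drop the `Dict` entirely and keep the index/value pairing in array form. Below is a minimal sketch, assuming Zygote is installed; the fixed input values and the `objective` name are hypothetical, chosen only so the result is easy to check by hand. It computes the same maximum of `i * w[i]` as the snippet above, once in explicit-argument mode and once with implicit `Params`.

```julia
using Zygote

# Hypothetical fixed input so the gradient can be checked by hand.
w = [0.5, -1.0, 2.0, 0.1, 0.3]

# Array-only formulation: the pairs i => w[i] never go into a Dict.
# eachindex(w) is a plain index range, so Zygote treats it as a constant.
objective(w) = maximum(eachindex(w) .* w)

# Explicit-argument mode: differentiate with respect to w directly.
g_explicit = Zygote.gradient(objective, w)[1]

# Implicit-parameter mode, in the style of the snippet above.
gs = Zygote.gradient(() -> objective(w), Zygote.Params([w]))
g_implicit = gs[w]

# Only the arg-max entry is nonzero: index 3, since 3 * 2.0 = 6.0 is the
# largest product, and d(3 * w[3]) / dw[3] = 3.
println(g_explicit)
println(g_implicit)
```

Because `maximum` over a plain array and broadcasted `.*` both have well-established gradient rules, this version avoids relying on how Zygote handles `Dict` construction or iteration at all.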
