I am performing a distributed calculation via `pmap`. For context, I have an expensive model that I need to differentiate, and I've chosen Zygote for this. Part of the process is defining an “adjoint” for `pmap`. I followed Zygote's implementation for `map`, which gives me the code below. This works with a single process.
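```julia
using Distributed
using Zygote

# Adjoint for pmap, modelled on Zygote's ∇map: run the forward pass on the
# workers and keep each element's (output, pullback) pair.
function ∇pmap(cx, wp, f, args...)
    ys_and_backs = pmap((args...) -> Zygote._pullback(cx, f, args...), wp, args...)
    if isempty(ys_and_backs)
        ys_and_backs, _ -> nothing
    else
        # Split the (output, pullback) pairs into separate collections.
        ys, backs = Zygote.unzip(ys_and_backs)
        ys, function (Δ)
            # Reverse pass: apply each pullback closure to its cotangent via pmap.
            Δf_and_args_zipped = pmap((f, δ) -> f(δ), wp, backs, Δ)
            Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
            # Accumulate the gradient with respect to `f` over all elements.
            Δf = reduce(Zygote.accum, Δf_and_args[1])
            # Gradients for (f, wp, args...); the worker pool gets `nothing`.
            (Δf, nothing, Δf_and_args[2:end]...)
        end
    end
end

# Restricting `wp` to a worker pool keeps this adjoint from also matching
# pool-less calls such as pmap(f, xs, ys).
Zygote.@adjoint function pmap(f, wp::AbstractWorkerPool, args::Union{AbstractArray,Tuple}...)
    ∇pmap(__context__, wp, f, args...)
end
```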
The issue comes in the line:

```julia
Δf_and_args_zipped = pmap((f, δ) -> f(δ), wp, backs, Δ)
```
Here, `backs` is returned from the original (forward) `pmap` call, and it contains a list of closures (the pullbacks), which are then mapped over as `f`. However, these closures are not necessarily defined on the child processes.
My usual strategy is to pass the data to the children via a call like:

```julia
@everywhere backs = $backs
```

But this doesn't work because `backs` is only defined within the function.
I see two possible solutions, though I'm sure there are others (which may involve abandoning `pmap` entirely):

- Pass the closures created in the first `pmap` call to all (or just the necessary) children. I'm not sure how to do this.
- Ensure a one-to-one correspondence between the processes used in the forward `pmap` call and in the reverse `pmap` call, so that the process that generated a closure is the one used when that closure is needed again. This seems brittle, and again, I'm not sure how to implement it.
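A minimal sketch of the second option, using only `Distributed` and a toy hand-written pullback instead of Zygote, so it is not a tested Zygote integration. All names here (`CLOSURE_REGISTRY`, `NEXT_ID`, `register_closure`, `call_registered`, `square_with_back`, `forward_pmap`, `backward_pmap`) are hypothetical. The assumption is that each closure never leaves the process that created it: the forward `pmap` sends back only a `(pid, id)` handle, and the reverse pass uses `remotecall_fetch` to run the closure on that same `pid`, so the pool used for the reverse pass no longer has to line up with the forward one.

```julia
using Distributed

# Hypothetical per-process registry: closures stay on the process that created
# them, and only a small (pid, id) handle is sent back to the caller.
@everywhere begin
    const CLOSURE_REGISTRY = Dict{Int,Any}()
    const NEXT_ID = Ref(0)

    # Store `g` locally and return a handle that is cheap to serialize.
    function register_closure(g)
        NEXT_ID[] += 1
        CLOSURE_REGISTRY[NEXT_ID[]] = g
        return (myid(), NEXT_ID[])
    end

    # Run the stored closure on its owning process and drop it from the registry.
    call_registered(id, δ) = pop!(CLOSURE_REGISTRY, id)(δ)

    # Toy stand-in for a pullback-producing call: y = x^2, back(δ) = 2x * δ.
    square_with_back(x) = (x^2, δ -> 2x * δ)
end

# Forward pass: each worker computes (y, back), keeps `back`, returns (y, handle).
function forward_pmap(f, wp::AbstractWorkerPool, xs)
    results = pmap(wp, xs) do x
        y, back = f(x)
        (y, register_closure(back))
    end
    return first.(results), last.(results)
end

# Reverse pass: route each cotangent to the process that owns its closure.
function backward_pmap(handles, Δs)
    return asyncmap(handles, Δs) do (pid, id), δ
        remotecall_fetch(call_registered, pid, id, δ)
    end
end

ys, handles = forward_pmap(square_with_back, default_worker_pool(), [1.0, 2.0, 3.0])
grads = backward_pmap(handles, ones(3))
```

In this sketch the registry is not thread-safe, entries leak if the reverse pass never runs, and a handle becomes useless if its worker dies between the two passes.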
I would appreciate any thoughts and help! The results from the `pmap` are then reduced, so if there is a way to make this work with `preduce` and friends instead, that would be great.
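For comparison, the forward pattern described here (a `pmap` followed by a reduction) can also be written with the reducer form of `@distributed`, which the `Distributed` standard library implements on top of the unexported `preduce`. This sketches only the forward computation, assuming a hypothetical `expensive_model` that maps one input to a number; it does not address differentiating through `@distributed`.

```julia
using Distributed

# Hypothetical placeholder for the expensive model; it must exist on every worker.
@everywhere expensive_model(x) = x^2

xs = [1.0, 2.0, 3.0]

# Current pattern: evaluate on the workers with pmap, then reduce on the caller.
total_pmap = sum(pmap(expensive_model, xs))

# Reducer form: each worker reduces its own chunk of `xs` with (+),
# and the partial results are combined on the caller.
total_distributed = @distributed (+) for x in xs
    expensive_model(x)
end
```

The reducer form only sends one partial result per worker back to the caller, rather than one value per element.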