I am performing a distributed calculation via `pmap`. For context, I have an expensive model that I need to differentiate, and for this I’ve chosen Zygote. Part of the process is defining an “adjoint” for `pmap`. I have followed Zygote’s implementation for `map`, which gives me the code below. This works when running on a single process.
```julia
using Distributed
using Zygote

# Adjoint for `pmap`, adapted from Zygote's adjoint for `map`.
function ∇pmap(cx, wp, f, args...)
    # Forward pass: each call returns (output, pullback closure).
    ys_and_backs = pmap((args...) -> Zygote._pullback(cx, f, args...), wp, args...)
    if isempty(ys_and_backs)
        ys_and_backs, _ -> nothing
    else
        ys, backs = Zygote.unzip(ys_and_backs)
        ys, function (Δ)
            # Reverse pass: apply each pullback to its incoming sensitivity.
            Δf_and_args_zipped = pmap((f, δ) -> f(δ), wp, backs, Δ)
            Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
            Δf = reduce(Zygote.accum, Δf_and_args[1])
            # `nothing` is the (non-)gradient for the worker pool `wp`.
            (Δf, nothing, Δf_and_args[2:end]...)
        end
    end
end

Zygote.@adjoint function pmap(f, wp, args::Union{AbstractArray,Tuple}...)
    ∇pmap(__context__, wp, f, args...)
end
```
The issue comes in this line:

```julia
Δf_and_args_zipped = pmap((f, δ) -> f(δ), wp, backs, Δ)
```
Here, `backs` is returned by the original (forward) `pmap` call and contains a list of pullback closures, which are then mapped over as `f`. However, these closures are not necessarily defined on the child processes.
My usual strategy is to pass the data to the children via a call like:

```julia
@everywhere backs = $backs
```

But this doesn’t work because `backs` is only defined within the function.
I see two possible solutions, though I’m sure there are others (which may involve abandoning `pmap` entirely):

- Pass the closures created in the first `pmap` call to all (or only the necessary) children. I’m not sure how to do this.
- Ensure a one-to-one correspondence between the processes used in the forward `pmap` call and in the reverse `pmap` call, so that the process that generated each closure is the one that uses it again. This seems brittle, and again, I’m not sure how to implement it.
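A minimal sketch of the second option, under several assumptions: instead of `pmap`’s dynamic scheduling, each element is assigned to a worker round-robin, and the `(y, back)` pair is left on that worker inside a `Future` returned by `remotecall`. Only `y` is fetched to the caller, and in the reverse pass each `Δ` is sent to the worker that owns the matching pullback. The helpers `remote_forward` and `remote_backward` are hypothetical, and `pb` stands for any function returning `(y, back)`. Inside `∇pmap` it would presumably be something like `x -> Zygote._pullback(cx, f, x)` for a single mapped argument, which assumes Zygote and `f` are loaded on every worker (e.g. via `@everywhere`).

```julia
using Distributed

# Hypothetical helper. `pb(x)` must return `(y, back)`, like a Zygote pullback.
# Elements are assigned to `pids` round-robin (static, no load balancing), and
# each `(y, back)` pair stays on its worker inside a Future.
function remote_forward(pb, xs, pids)
    handles = [(pid, remotecall(pb, pid, x))
               for (pid, x) in zip(Iterators.cycle(pids), xs)]
    # Extract `y` on the owning worker, so only `y` (not the closure) is sent back.
    ys = fetch.([remotecall(r -> first(fetch(r)), pid, fut) for (pid, fut) in handles])
    return ys, handles
end

# Hypothetical helper. Runs each stored pullback on the worker that created it.
function remote_backward(handles, Δs)
    futs = [remotecall((r, δ) -> last(fetch(r))(δ), pid, fut, δ)
            for ((pid, fut), δ) in zip(handles, Δs)]
    return fetch.(futs)
end
```

The pullbacks stay alive on the workers for as long as `handles` is referenced on the caller, so `remote_backward` can be called more than once. The price is giving up `pmap`’s load balancing, which is the brittleness mentioned above.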
I’d appreciate any thoughts and help! The results from the `pmap` are then reduced, so if there’s a way to make this work with `preduce` and friends instead, that would be great.
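Since the mapped results are reduced afterwards, another possible route avoids sending pullback closures between processes altogether. It swaps the deferred pullback for eager per-element gradients, and it only applies when each mapped output is a scalar that gets summed. Each worker computes its element’s value and gradient, and only plain numbers come back. The sketch below assumes a hypothetical `val_and_grad(x)` returning `(f(x), ∂f/∂x)`. With Zygote loaded on the workers, this could presumably be built from `Zygote.withgradient`, but that wiring is an assumption.

```julia
using Distributed

# Hypothetical helper for a loss of the form sum(f(x) for x in xs), f scalar-valued.
# `val_and_grad(x)` must return `(f(x), ∂f/∂x)`. Both are computed on the worker,
# so only plain values (no pullback closures) travel between processes.
function pmap_sum_with_grad(val_and_grad, wp, xs)
    results = pmap(val_and_grad, wp, xs)
    total = sum(first, results)      # forward value: Σᵢ f(xs[i])
    grads = map(last, results)       # ∂total/∂xs[i] = f′(xs[i])
    # Pullback for a scalar seed Δ = ∂L/∂total (chain rule: Δ * f′(xs[i])).
    back(Δ) = map(g -> Δ * g, grads)
    return total, back
end
```

For vector-valued outputs this would require full per-element Jacobians, which is likely too expensive for a costly model.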