Passing constructed closures to child processes

I am performing a distributed calculation via pmap. For context, I have an expensive model that I need to differentiate, and for this I’ve chosen Zygote. Part of the process is defining an “adjoint” for pmap. I have followed the implementation for map, which gives the code below. This works for a single process.

using Distributed
using Zygote

function ∇pmap(cx, wp, f, args...)
  # Forward pass: build a (value, pullback) pair for every element on the workers.
  ys_and_backs = pmap((args...) -> Zygote._pullback(cx, f, args...), wp, args...)
  if isempty(ys_and_backs)
    ys_and_backs, _ -> nothing
  else
    ys, backs = Zygote.unzip(ys_and_backs)
    ys, function (Δ)
      # Reverse pass: apply each stored pullback to its cotangent, again via pmap.
      Δf_and_args_zipped = pmap((f, δ) -> f(δ), wp, backs, Δ)
      Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
      Δf = reduce(Zygote.accum, Δf_and_args[1])
      (Δf, nothing, Δf_and_args[2:end]...)  # `nothing` is the gradient w.r.t. wp
    end
  end
end

Zygote.@adjoint function pmap(f, wp, args::Union{AbstractArray,Tuple}...)
  ∇pmap(__context__, wp, f, args...)
end

The issue comes in the line:

Δf_and_args_zipped = pmap((f, δ) -> f(δ), wp, backs, Δ)

Here, backs is returned from the original (forward) pmap call, and it contains the list of pullback closures which are then mapped over as f. However, these closures are not necessarily defined on the child processes.

My usual strategy is to pass the data to the children via a call like:

@everywhere backs = $backs

But this doesn’t work here, because backs is only defined inside the function.
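A workaround that I believe works from inside a function (hedged sketch; `send_to_workers` is my own name, not a library function) is to assign into each worker’s `Main` explicitly via `remotecall_fetch`:

```julia
using Distributed

# Hypothetical helper: bind `val` to `name` in Main on every worker.
# Unlike `@everywhere name = $val`, this can be called with a local variable.
function send_to_workers(name::Symbol, val)
    for w in workers()
        remotecall_fetch(Core.eval, w, Main, Expr(:(=), name, val))
    end
end
```

e.g. `send_to_workers(:backs, backs)` inside the pullback. Note that this still has to serialize the closures to the workers, so on its own it doesn’t avoid the missing-type problem.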

I see two possible solutions, though I’m sure there are others (which may involve abandoning pmap entirely):

  1. Pass the closures created in the first pmap call to all (or the necessary) children. I’m not sure how to do this.
  2. Ensure a one-to-one correspondence between the processes used in the forward pmap call and in the reverse pmap call, so that each closure is invoked on the process that created it. This seems brittle, and again I’m not sure how to implement it.
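For option 2, a hedged sketch of one way to get that correspondence without shipping closures at all: store each pullback in a worker-local `Dict` keyed by element index, return only a small `(worker id, value, index)` tag from the forward pass, and route each cotangent back to the owning worker with `remotecall_fetch`. All names here are mine, and the pullback is a placeholder standing in for `Zygote._pullback`:

```julia
using Distributed

# Worker-local store: each worker keeps the pullbacks it created,
# keyed by element index.
@everywhere const LOCAL_BACKS = Dict{Int,Any}()

# Forward pass: stash the pullback locally; return only a small
# (worker id, value, index) tag to the caller.
function tagged_forward(f, xs)
    pmap(collect(enumerate(xs))) do (i, x)
        LOCAL_BACKS[i] = δ -> δ * x   # placeholder pullback capturing x
        (myid(), f(x), i)
    end
end

# Reverse pass: route each cotangent to the worker that owns the
# matching pullback, instead of moving the closure itself.
function tagged_reverse(tags, Δs)
    map(zip(tags, Δs)) do ((pid, _y, i), δ)
        remotecall_fetch((j, d) -> LOCAL_BACKS[j](d), pid, i, δ)
    end
end
```

In real use the keys would need to be unique per gradient call, and the placeholder pullback would be replaced by whatever `Zygote._pullback` returns on the worker.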

Would appreciate any thoughts and help! The results from the pmap are then reduced, so if there’s a way to make this work with preduce and friends instead, that would be great.

An update on this:

I tried breaking down my example, but couldn’t reproduce the issue in anything but my real use case. For instance, the following works:

using Distributed, SparseArrays, LinearAlgebra  # SparseArrays/LinearAlgebra for sprand and I
@everywhere begin
  using Zygote
  function f_pmap_zygote_solve(A, bs)
    xs = pmap(b -> A \ b, wp, bs)
    return sum(sum(xs))
  end
end
wp = default_worker_pool()
A = sprand(200, 200, 0.01) + 200*I
b0s = [randn(200) for i = 1:10]
Zygote.gradient(f_pmap_zygote_solve, A, b0s)

with the adjoint of pmap defined above. I still think the problem is the closure, as the error returned is:

ERROR: LoadError: UndefVarError: ##493#back#178 not defined

then a long list of deserialization steps, culminating in a call to an anonymous function:

deserialize_datatype(::Distributed.ClusterSerializer{Sockets.TCPSocket}, ::Bool) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:1115
handle_deserialize(::Distributed.ClusterSerializer{Sockets.TCPSocket}, ::Int32) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:771
deserialize(::Distributed.ClusterSerializer{Sockets.TCPSocket}) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:731
deserialize_datatype(::Distributed.ClusterSerializer{Sockets.TCPSocket}, ::Bool) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:1134
handle_deserialize(::Distributed.ClusterSerializer{Sockets.TCPSocket}, ::Int32) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:771
deserialize(::Distributed.ClusterSerializer{Sockets.TCPSocket}) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:731
deserialize_datatype(::Distributed.ClusterSerializer{Sockets.TCPSocket}, ::Bool) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:1139
handle_deserialize(::Distributed.ClusterSerializer{Sockets.TCPSocket}, ::Int32) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:771
deserialize(::Distributed.ClusterSerializer{Sockets.TCPSocket}) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:731
handle_deserialize(::Distributed.ClusterSerializer{Sockets.TCPSocket}, ::Int32) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:778
deserialize(::Distributed.ClusterSerializer{Sockets.TCPSocket}) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:731
(::getfield(Serialization, Symbol("##3#4")){Distributed.ClusterSerializer{Sockets.TCPSocket}})(::Int64) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:869
ntuple(::getfield(Serialization, Symbol("##3#4")){Distributed.ClusterSerializer{Sockets.TCPSocket}}, ::Int64) at ./tuple.jl:136
deserialize_tuple(::Distributed.ClusterSerializer{Sockets.TCPSocket}, ::Int64) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:869
handle_deserialize(::Distributed.ClusterSerializer{Sockets.TCPSocket}, ::Int32) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:761
deserialize_msg(::Distributed.ClusterSerializer{Sockets.TCPSocket}) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Serialization/src/Serialization.jl:731
#invokelatest#1 at ./essentials.jl:742 [inlined]
invokelatest at ./essentials.jl:741 [inlined]
message_handler_loop(::Sockets.TCPSocket, ::Sockets.TCPSocket, ::Bool) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Distributed/src/process_messages.jl:160
process_tcp_streams(::Sockets.TCPSocket, ::Sockets.TCPSocket, ::Bool) at /home/ubuntu/julia/usr/share/julia/stdlib/v1.1/Distributed/src/process_messages.jl:117
(::getfield(Distributed, Symbol("##105#106")){Sockets.TCPSocket,Sockets.TCPSocket,Bool})() at ./task.jl:259
Stacktrace:
 [1] (::getfield(Base, Symbol("##696#698")))(::Task) at ./asyncmap.jl:178

I was able to create an M(B)E (I think I wasn’t adding enough processes last time).

using Distributed
addprocs(4, enable_threaded_blas=true) # fails with enable_threaded_blas=false as well

@everywhere using Zygote

@everywhere begin
  function ∇pmap(cx, wp, f, args...)
    ys_and_backs = pmap((args...) -> Zygote._pullback(cx, f, args...), wp, args...)
    if isempty(ys_and_backs)
      ys_and_backs, _ -> nothing
    else
      ys, backs = Zygote.unzip(ys_and_backs)
      ys, function (Δ)
        Δf_and_args_zipped = pmap((f, δ) -> f(δ), wp, backs, Δ)
        Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
        Δf = reduce(Zygote.accum, Δf_and_args[1])
        (Δf, nothing, Δf_and_args[2:end]...)
      end
    end
  end

  Zygote.@adjoint function pmap(f, wp, args::Union{AbstractArray,Tuple}...)
    ∇pmap(__context__, wp, f, args...)
  end

  function test_grad(x)
    return x^2 + 3log(x)
  end
  function test_grad_pmap(x)
    return sum(pmap(test_grad, wp, x))
  end
end

wp = default_worker_pool()
Zygote.gradient(test_grad, 1.0) # works
Zygote.gradient(test_grad_pmap, rand(100)) # breaks
Zygote.gradient(x -> sum(pmap(y -> y^2, wp, x)), rand(100)) # breaks
Zygote.gradient(x -> sum(pmap(sum, wp, x)), rand(100)) # works

I installed Julia 1.4 and the code worked immediately! (Previously I was trying 1.1.) Hats off to the dev team and Zygote for anticipating and solving this issue months ago :slight_smile: