Julia Distributed, AllReduce, and Distributed Training

I’m working with Distributed Julia on multiple hosts and want to write code that does distributed training of a machine learning model on distributed data. I don’t want to use MPI. I have my ClusterManager working correctly and can use `addproc_mysystem(n)` without any issues.

The traditional way to do distributed training is to put an identical copy of the model on each worker, partition the training data, and give each worker one partition. Each worker computes a gradient with respect to its portion of the training data, then calls `AllReduce(Add, gradients) / num_workers`. The gradients from the AllReduce are what’s used to update the model.
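Dividing by the number of workers is what makes the averaged result match a full-batch gradient, provided the loss is a mean over examples and the partitions are the same size. A small single-process check, assuming a mean-squared-error loss on random data (`mse_grad` and `data_parallel_grad_demo` are illustrative names, not from any package):

```julia
using Random

# Gradient of the mean-squared-error loss L(w) = mean((X*w .- y).^2) with respect to w.
mse_grad(w, X, y) = 2 .* (X' * (X * w .- y)) ./ size(X, 1)

function data_parallel_grad_demo(; n_workers = 4, n_per_worker = 25, d = 3)
    rng = MersenneTwister(0)
    X = randn(rng, n_workers * n_per_worker, d)
    y = randn(rng, n_workers * n_per_worker)
    w = randn(rng, d)

    # Equal-sized contiguous partitions, one per simulated worker.
    parts = [((k - 1) * n_per_worker + 1):(k * n_per_worker) for k in 1:n_workers]
    local_grads = [mse_grad(w, X[p, :], y[p]) for p in parts]

    # What AllReduce(Add, gradients) / num_workers would leave on every worker.
    averaged = sum(local_grads) ./ n_workers
    full_batch = mse_grad(w, X, y)
    return averaged, full_batch
end

averaged, full_batch = data_parallel_grad_demo()
println(averaged ≈ full_batch)   # true
```

With unequal partition sizes the plain average no longer matches; each local gradient would need to be weighted by its partition size before the sum.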

Question 1: what’s the best strategy for implementing an AllReduce with Julia’s Distributed module?
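A simple strategy to start from, before anything ring-based, is a star-shaped reduction driven by the master process: it gathers every worker’s gradient with `remotecall_fetch`, averages them, and the updated parameters go back out with the next call. A sketch, assuming local workers from `addprocs` and a hypothetical `local_grad` function in place of a real model:

```julia
using Distributed
nprocs() < 3 && addprocs(3 - nprocs())   # local workers, for illustration only

# Hypothetical stand-in for a worker's gradient computation on its own data shard.
@everywhere local_grad(params) = fill(Float64(myid()), size(params))

# Star-shaped all-reduce: the master fetches every worker's gradient concurrently,
# sums them, and divides by the number of workers.
function allreduce_mean(params)
    grads = asyncmap(w -> remotecall_fetch(local_grad, w, params), workers())
    return sum(grads) ./ length(grads)
end

g = allreduce_mean(zeros(4))
```

The drawback is that all gradient traffic passes through the master, which becomes the bottleneck as the number of workers grows; ring-based AllReduce spreads that traffic across all the links instead.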

Question 2: `@everywhere Z = 1` lets me declare the variable `Z` in each worker’s `Main` namespace, but the value of `Z` is not shared between workers (right?). If I do `@everywhere ch3 = RemoteChannel(() -> Channel{Int}(10), 3)`, how does Julia know that `ch3` in each worker’s namespace references the same underlying RemoteChannel?
As a more complicated example, to implement a ring AllReduce I’m doing
`@everywhere channelTable = [(RemoteChannel(() -> Channel{Int}(10), w_i), RemoteChannel(() -> Channel{Int}(10), w_i)) for w_i in workers()]`. `channelTable[2][1]` does seem to reference the same RemoteChannel on both workers when I run `@spawnat 2 put!(channelTable[2][1], 7)` and `@spawnat 3 take!(channelTable[2][1])`. Is this really doing what I think it is, that is, passing a value from worker 2 to worker 3?
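A sketch of both points, assuming local workers from `addprocs` (`readZ` is a hypothetical helper). As far as I understand, each process that evaluates `RemoteChannel(...)` under `@everywhere` creates its own separate channel, so a reliable way to share one is to create it once on the master and pass the handle to the workers as an argument; every copy of the handle carries the same remote-reference id. The `@spawnat` test above likely appears to work because both closures reference the global `channelTable` in `Main`, and Distributed ships the caller’s (master’s) value of such globals along with the closure. For the same reason, `readZ` below is a named function defined everywhere rather than a closure like `@fetchfrom 2 Z`.

```julia
using Distributed
nprocs() < 3 && addprocs(3 - nprocs())   # local workers, for illustration only
w_a, w_b = workers()[1:2]

# 1. @everywhere evaluates the expression separately on every process,
#    so each process ends up with its own, independent Z.
@everywhere Z = myid()
@everywhere readZ() = Z                  # hypothetical helper: reads the local Z
zs = [remotecall_fetch(readZ, w) for w in (w_a, w_b)]

# 2. Create the channel exactly once, on the master; its buffer lives on w_b.
ch = RemoteChannel(() -> Channel{Int}(10), w_b)

# Passing the handle as an argument gives each worker a copy with the same
# remote-reference id, so both address the same underlying Channel.
ids = [remotecall_fetch(c -> Distributed.remoteref_id(c), w, ch) for w in (w_a, w_b)]

put_task = remotecall(c -> put!(c, 7), w_a, ch)   # worker w_a sends 7
received = remotecall_fetch(take!, w_b, ch)        # worker w_b receives it
wait(put_task)
```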

Not sure whether topic categories can be changed after the fact, but this would probably get more traction in #domain:parallel. You may also be interested in JuliaParallel/UCX.jl on GitHub.

Okay, I’ll see what I can do to move it over there.

Short update:

I got this bit of code working… I think.

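```julia
using Distributed
using MyClusterManagerAprun   # my cluster manager package; provides addprocs_aprun

addprocs_aprun(16)

# Naive ring all-reduce: every worker ends up with the elementwise sum of all
# workers' `dat` arrays. Each worker puts into its own channel and takes from its
# right-hand neighbour's channel, so partial sums move one step per iteration.
@everywhere function RingAllReduce(channelTable, masterChannel)
    m_id = myid()
    n_w = nworkers()
    dat = rand(1:200, (4, 4, 156, 12))   # stand-in for this worker's gradient
    w_l = sort(workers())

    # Map worker ids to ring positions 0..n_w-1 and back.
    w_i_i_map = Dict(w_i => i for (i, w_i) in zip(0:(n_w-1), w_l))
    i_w_i_map = Dict(i => w_i for (i, w_i) in zip(0:(n_w-1), w_l))

    right_id = i_w_i_map[(w_i_i_map[m_id] + 1) % n_w]

    _dat_accum = copy(dat)
    time_el = @elapsed for _i in 1:(n_w-1)
        put!(channelTable[m_id], _dat_accum)              # publish my partial sum
        _dat_accum = take!(channelTable[right_id]) + dat  # neighbour's partial sum + mine
    end
    # After n_w-1 steps every worker already holds the full sum; this extra exchange
    # receives the right neighbour's (identical) full sum, so the result is unchanged.
    put!(channelTable[m_id], _dat_accum)
    _dat_accum = take!(channelTable[right_id])
    put!(masterChannel, (m_id, time_el))
end

# All channels are created once, on the master, and passed to the workers as arguments.
masterChannel = RemoteChannel(() -> Channel{Tuple}(20), 1)
channelTable = Dict(w_i => RemoteChannel(() -> Channel{Array{Int64,4}}(1), w_i) for w_i in workers())

futures = [remotecall(RingAllReduce, w_i, channelTable, masterChannel) for w_i in workers()]
```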

Then I can use masterChannel to monitor progress like this,

```julia
[take!(masterChannel) for i in 1:100 if isready(masterChannel)]
```

This pattern seems to be working for me at the moment. I’ll post again when I have a better test setup for this.
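In the code above every step moves a whole array between neighbours. A common refinement, the chunked ring AllReduce, splits each array into one chunk per worker and runs a reduce-scatter followed by an all-gather, so each worker sends about 2(n-1)/n of its array in total instead of a full copy per step. The sketch below is one way to write it with RemoteChannels, assuming local workers from `addprocs`, `Vector{Float64}` data, and hypothetical names (`chunk_ranges`, `ring_allreduce!`, `inbox`):

```julia
using Distributed
nprocs() < 4 && addprocs(4 - nprocs())   # local workers, for illustration only

@everywhere begin
    # Split 1:len into n contiguous, near-equal ranges (some are empty if len < n).
    chunk_ranges(len, n) = [(div((k - 1) * len, n) + 1):div(k * len, n) for k in 1:n]

    # Chunked ring all-reduce: reduce-scatter, then all-gather. `ring` lists the
    # participating worker ids in ring order; `inbox[w]` is a RemoteChannel owned by
    # worker w. In every step a worker sends one chunk to its right-hand neighbour
    # and receives one chunk from its left-hand neighbour.
    function ring_allreduce!(x::Vector{Float64}, ring::Vector{Int}, inbox)
        n = length(ring)
        r = findfirst(==(myid()), ring) - 1          # 0-based position in the ring
        right = ring[mod(r + 1, n) + 1]
        chunks = chunk_ranges(length(x), n)
        idx(c) = chunks[mod(c, n) + 1]               # index range of chunk c (mod n)

        # Reduce-scatter: afterwards this worker holds the full sum of chunk r + 1.
        for s in 0:(n - 2)
            put!(inbox[right], x[idx(r - s)])        # slicing copies the chunk
            x[idx(r - s - 1)] .+= take!(inbox[myid()])
        end
        # All-gather: pass completed chunks around until every worker has all of them.
        for s in 0:(n - 2)
            put!(inbox[right], x[idx(r + 1 - s)])
            x[idx(r - s)] .= take!(inbox[myid()])
        end
        return x
    end
end

ring = sort(workers())
inbox = Dict(w => RemoteChannel(() -> Channel{Vector{Float64}}(2), w) for w in ring)

# All workers must run concurrently, hence asyncmap. Worker w contributes fill(w, 10).
results = asyncmap(w -> remotecall_fetch(ring_allreduce!, w, fill(Float64(w), 10), ring, inbox), ring)
```

Each inbox has a single producer (the left-hand neighbour), so FIFO ordering keeps the reduce-scatter and all-gather messages in step; the capacity of 2 is an arbitrary choice.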
