Data-parallel training with conv nets in Julia

distributed

#1

I have a problem that lends itself well to data-parallel training (see, e.g., https://indico.cern.ch/event/587955/contributions/2937539/). I’ve been using a library built on top of TensorFlow (https://github.com/matex-org/matex), but I’m running into some issues with TensorFlow on a new machine, and I’m wondering whether Flux or one of the other deep learning libraries in Julia could do the job as well.
Does anybody have experience with training on multiple nodes in Julia? Can I just use DistributedArrays for my data sets and have everything magically work?
Any pointers would be appreciated.
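For the multi-node case, here is a rough sketch of synchronous data-parallel training using Julia's standard-library `Distributed` module (not DistributedArrays, and not Flux). Each worker process computes the gradient on its own shard via `pmap`, and the master process sums the gradients and updates the weights. The least-squares model, `shard_gradient`, and `train` are hypothetical stand-ins for a real network. The sketch starts two local workers; on a cluster you would typically pass a list of host names to `addprocs` or use a cluster manager instead.

```julia
using Distributed
nprocs() == 1 && addprocs(2)   # start two local worker processes (illustration only)

# Defined on every process so that workers can call it.
# Gradient of the *sum* of squared errors of a linear model y ≈ X' * w on one shard
# (hypothetical stand-in for a conv net); X is features × samples.
@everywhere function shard_gradient(w, X, y)
    r = X' * w .- y
    return 2 .* (X * r)
end

# Synchronous data-parallel gradient descent: at every step, each shard's gradient
# is computed on a worker via pmap, and the master sums them and updates w.
function train(X, y; nshards = nworkers(), lr = 0.1, epochs = 200)
    n = size(X, 2)
    chunk = cld(n, nshards)
    # Contiguous shards of samples (columns of X)
    shards = [(X[:, i:min(i + chunk - 1, n)], y[i:min(i + chunk - 1, n)]) for i in 1:chunk:n]
    w = zeros(size(X, 1))
    for _ in 1:epochs
        # The closure captures w, so the current weights are sent to the workers each step
        grads = pmap(s -> shard_gradient(w, s[1], s[2]), shards)
        w .-= lr .* sum(grads) ./ n   # summed SSE gradients / n = gradient of the mean squared error
    end
    return w
end
```

This version re-sends both the weights and the shard data on every step, which keeps it short but becomes expensive for large models. A more serious implementation would keep each shard resident on its worker and exchange only gradients.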


#2

I did that with Flux: I ran a separate copy of the model on each thread and then averaged the results. It was relatively easy, and in the end it took only a couple of lines of code.

I did it that way because I needed multiplications with large sparse matrices. In the end, it was about three times faster on the CPU than TensorFlow on the GPU.
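Here is a minimal sketch of the thread-based scheme described above (work on separate shards in parallel, then average), written with plain arrays instead of Flux so that it is self-contained. It averages per-shard gradients. `loss_grad` and `threaded_step!` are hypothetical names, and the toy least-squares model stands in for a real network. Julia has to be started with several threads (e.g. `julia --threads 4` or by setting `JULIA_NUM_THREADS`) for the loop to actually run in parallel.

```julia
using Base.Threads: @threads
using Statistics: mean

# Hypothetical stand-in for "the model": linear least squares with weight vector w.
# Returns the gradient of the mean squared error on one shard (X: features × samples).
loss_grad(w, X, y) = (2 / length(y)) .* (X * (X' * w .- y))

# One synchronous data-parallel step: each shard is processed on some thread,
# then the per-shard gradients are averaged and applied once.
function threaded_step!(w, shards; lr = 0.1)
    grads = Vector{Vector{Float64}}(undef, length(shards))
    @threads for i in 1:length(shards)
        X, y = shards[i]
        grads[i] = loss_grad(w, X, y)   # each iteration writes only its own slot
    end
    w .-= lr .* mean(grads)             # equal weights: exact only for equal-size shards
    return w
end
```

Because `loss_grad` is a pure function, all threads can safely read the same `w`. With a framework that stores gradients inside the parameters themselves, as Tracker-based Flux did, each thread needs its own model copy, which is why the snippet below uses `deepcopy(model)`.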


#3

Cool! Do you have an example that you could share?


#4

I can share a snippet; it was roughly like this.

Say model is your model. I did:

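```julia
using Flux                       # Tracker-based Flux (before the switch to Zygote)
using Base.Threads: @threads, nthreads

# Backpropagate `loss` for one model copy on one data shard; return the loss value.
function _back!(model, loss, ds)
  l = loss(model, ds)
  isinf(l.tracker.data) && error("inf in the model")
  isnan(l.tracker.data) && error("nan in the model")
  Flux.Tracker.back!(l)          # accumulates gradients in this copy's parameters
  l.tracker.data
end

# Sync every copy with the master parameters, backprop one shard per thread,
# then combine (average) the copies into `pars`.
# NOTE: copy!(s, pars) and mean!(pars, parss...) for Params are not provided
# by Flux; you have to implement them yourself.
function Flux.Tracker.back!(pars, models, parss, dss, loss)
  foreach(s -> copy!(s, pars), parss)
  l = zeros(length(dss))         # one loss slot per shard
  @threads for i in 1:length(dss)
    l[i] = _back!(models[i], loss, dss[i])
  end
  mean!(pars, parss...)
  l
end

# One model copy per thread; dss must hold one shard per copy.
models = [deepcopy(model) for i in 1:nthreads()]
parss = map(params, models)
pars = params(model)
Flux.Tracker.back!(pars, models, parss, dss, loss)
```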

dss was a vector holding the data for each thread, and loss was the loss function.
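Building such a per-thread vector of shards could look roughly like this. `shard_data` is a hypothetical helper, not part of Flux. It assumes the samples are the columns of `X`, paired with the entries of a label vector `y`.

```julia
using Base.Threads: nthreads

# Split the columns (samples) of X and the entries of y into n nearly equal,
# contiguous shards; returns a vector of (X_shard, y_shard) tuples.
# Shards can be empty if n exceeds the number of samples.
function shard_data(X::AbstractMatrix, y::AbstractVector, n::Integer = nthreads())
    N = size(X, 2)
    @assert length(y) == N "X and y must have the same number of samples"
    bounds = round.(Int, range(0, stop = N, length = n + 1))   # shard boundaries 0 = b₁ ≤ … ≤ bₙ₊₁ = N
    return [(X[:, bounds[k]+1:bounds[k+1]], y[bounds[k]+1:bounds[k+1]]) for k in 1:n]
end
```

With the default `n = nthreads()` you get one shard per thread, matching one model copy per thread.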

You might need to implement some of the missing functions, such as averaging or copying parameters, but this should give you the idea of how I did it.
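As a starting point for those missing pieces, here are framework-agnostic versions that work on any ordered collections of arrays. `copyparams!` and `meanparams!` are hypothetical names. With Tracker-era Flux you would apply them to the underlying parameter (or gradient) arrays of each model copy. How to extract those arrays depends on the Flux version, so that part is left out here.

```julia
# Copy every array of `src` into the matching array of `dst` in place.
# `dst` and `src` must contain same-shaped arrays in the same order.
function copyparams!(dst, src)
    for (d, s) in zip(dst, src)
        copyto!(d, s)
    end
    return dst
end

# Overwrite each array in `dst` with the elementwise mean of the matching arrays
# in `srcs...` (one collection per model copy / thread; at least one is required).
function meanparams!(dst, srcs...)
    for group in zip(dst, srcs...)
        d, rest = group[1], group[2:end]   # `rest` is a tuple of same-shaped arrays
        # sum(rest) is allocated before writing, so `dst` may also appear in `srcs`
        d .= sum(rest) ./ length(rest)
    end
    return dst
end
```

For example, `meanparams!(master, copy1, copy2)` leaves each array of `master` holding the average of the corresponding arrays in `copy1` and `copy2`.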

Hope this has helped.


#5

Thanks! I’m sure it’ll take me a while to get up to speed (I’m a total Flux newbie), but I appreciate your help.