I have a problem that lends itself well to data-parallel training (see, e.g., the CHEP 2018 talk "Scaling studies for deep learning in LArTPC event classification", Sofia, Bulgaria, 9–13 July 2018). I’ve been using a library that sits on top of TensorFlow (MaTEx, the Machine Learning Toolkit for Extreme Scale, github.com/matex-org/matex), but I’m facing some issues with TensorFlow on a new machine, and I’m wondering whether Flux or one of the other deep learning libraries in Julia could do the job as well.

Does anybody have experience with training on multiple nodes in Julia? Can I just use DistributedArrays for my data sets somehow and everything magically works?

Any pointers would be appreciated.

I have done that with Flux, running a separate copy of the model on each thread and then averaging the results. It was relatively easy and, in the end, only a couple of lines of code.

I did that because I need multiplication with large sparse matrices. In the end, I was about three times faster on CPU than TensorFlow on GPU.

Cool! Do you have an example that you could share?

I can share a snippet; it was something like this. Say `model` is your model; I did

```
using Flux
using Base.Threads: @threads, nthreads

# Run the loss and backward pass for one model replica on its chunk of data,
# guarding against the loss going infinite or NaN.
function _back!(model, loss, ds)
    l = loss(model, ds)
    isinf(l.tracker.data) && error("inf in the model")
    isnan(l.tracker.data) && error("nan in the model")
    Flux.Tracker.back!(l)
    l.tracker.data
end

# Copy the shared parameters into each replica, run the backward passes
# in parallel (one data chunk per thread), then average the replicas'
# parameters back into `pars`.
function Flux.Tracker.back!(pars, models, parss, dss, loss)
    foreach(s -> copy!(s, pars), parss)
    l = zeros(nthreads())
    @threads for i in 1:length(dss)
        l[i] += _back!(models[i], loss, dss[i])
    end
    mean!(pars, parss...)
    l
end

models = [deepcopy(model) for i in 1:nthreads()]
parss = map(params, models)
pars = params(model)
Flux.Tracker.back!(pars, models, parss, dss, loss)
```

`dss` was a vector with data for each thread, and `loss` was the loss function.
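For completeness, here is one way such a `dss` vector could be built, by partitioning a dataset into one chunk per thread (the helper name `make_dss` is my own, not part of the snippet above):

```
using Base.Threads: nthreads

# Hypothetical helper: split a vector of samples into one chunk per
# thread, producing a `dss`-style vector with one entry per thread.
function make_dss(data)
    n = nthreads()
    chunk = cld(length(data), n)   # ceiling division: chunk size per thread
    [data[(i-1)*chunk + 1:min(i*chunk, length(data))] for i in 1:n]
end
```

The last chunk may be shorter (or empty) when the data does not divide evenly.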

You might need to implement some of the missing functions, such as averaging parameters or copying them, but that should give you the idea of how I did it.
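As one example of such a missing function, a parameter-averaging helper could look like this. This is a minimal sketch over plain arrays; `average_params!` is a hypothetical name, and real Flux parameter collections hold tracked arrays, so some adaptation would be needed:

```
# Hypothetical helper: overwrite each array in `pars` with the
# element-wise mean of the corresponding arrays from every per-thread
# parameter collection in `parss`.
function average_params!(pars, parss...)
    for (i, p) in enumerate(pars)
        p .= parss[1][i]
        for ps in parss[2:end]
            p .+= ps[i]
        end
        p ./= length(parss)
    end
    pars
end
```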

Hope this helps.

Thanks! I’m sure it’ll take me a while to get up to speed (total Flux newbie), but I appreciate your help.