Parallel Reductions with Zygote

I am trying to speed up the gradient computation of a model of the form

```
sum(m -> m(x), models)
```

where `models` is an array of constituent models. This is (supposed to be) easily parallelizable, but I am having trouble making it work with Zygote.
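
For context, here is a minimal sketch of the serial baseline; the toy closures and input below are purely illustrative stand-ins for the actual constituent models:

```
using Zygote

# Illustrative stand-ins: each "model" is a closure returning a scalar.
models = [let w = randn(3); x -> sum(w .* x) end for _ in 1:8]
x = randn(3)

# Serial forward and backward pass through the summed model outputs.
loss(ms) = sum(m -> m(x), ms)
y, back = Zygote.pullback(loss, models)
grads = back(one(y))[1]   # gradient with respect to each constituent model
```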

First, I tried the parallel `sum` implementation in ThreadsX.jl, but I get worse-than-serial performance for the backward pass on an Intel processor (a 24-core Intel(R) Xeon(R) CPU X5690 @ 3.47GHz), while the forward pass is accelerated. Further, I get a segfault running natively on an Apple M1 Pro, which could be related to threading issues that are still present on ARM.

Second, writing my own parallel version of `sum(f, x)` that is compatible with Zygote also proved more difficult than expected, because Zygote does not support array mutation, which precludes the solutions discussed here.
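
As a small illustration of that restriction (the `accum_sum` helper and `abs2` are purely for demonstration), an accumulator-style loop runs fine forwards but cannot be differentiated because it writes into an array:

```
using Zygote

# In-place accumulation: fine for the forward pass, rejected by Zygote's reverse pass.
function accum_sum(f, xs)
    acc = zeros(1)
    for x in xs
        acc[1] += f(x)        # setindex! on an array
    end
    return acc[1]
end

accum_sum(abs2, randn(10))                            # works
# Zygote.gradient(v -> accum_sum(abs2, v), randn(10))
# errors, since mutating arrays is not supported by Zygote
```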

This leaves writing a custom adjoint for a custom parallel implementation as the remaining viable option. I don’t yet know how to do this effectively and would be grateful for any help.

Also, here is an implementation of a parallel summation that works well for the forward pass, but that Zygote can’t differentiate through (because of a `try-catch` block apparently included in the `@threads` implementation, and secondarily, because of the array mutation happening in the parallel section).

```
using Base.Threads

function parallel_sum(f, x::AbstractVector)
    n, k = length(x), nthreads()
    m = n ÷ k                     # chunk length per thread
    fx1 = f(x[1])                 # evaluate once to get the correct output type
    sum_x = fill(fx1, k)          # per-thread partial sums (overwritten below)
    @threads for i in 1:k
        y = @view x[(i-1)*m+1 : i*m]
        sum_x[i] = sum(f, y)
    end
    result = zero(fx1)
    for i in 1:k
        result += sum_x[i]
    end
    if rem(n, k) != 0             # pick up the leftover elements
        result += sum(f, @view(x[m*k+1:end]))
    end
    return result
end
```
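
For reference, a quick sketch of how this behaves (the `abs2` example is illustrative):

```
using Zygote

xs = randn(10_000)
parallel_sum(abs2, xs) ≈ sum(abs2, xs)    # forward pass works and runs in parallel

# Zygote.gradient(v -> parallel_sum(abs2, v), xs)
# fails: Zygote cannot differentiate through the try/catch that @threads
# introduces, nor through the writes into sum_x.
```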

We recently had some success using `Folds` here. Don’t know if this helps with `Zygote` though.

Thank you for the pointer @goerch!

I just tried out `Folds`, but realized it relies on `Transducers`, the same backend as `ThreadsX`, and for this reason it has the same performance (Intel) and segfault (ARM) issues as `ThreadsX` in conjunction with `Zygote`.

2 Likes

Naive parallel implementations we discussed here.

Edit: sorry, corrected link.
Edit: just for reference, here is my latest parallelized version:

```
using Base.Threads

function naive_kbn_parallel(xs)  # credit to @tkf from another thread
    len = length(xs)
    nt = nthreads()
    ys = Vector{Float64}(undef, 2 * nt)   # per-thread (sum, compensation) pairs
    chunk = (len + nt - 1) ÷ nt
    @threads for i in 1:nt
        s, c = _naive_kbn(@view xs[(i - 1) * chunk + 1:min(i * chunk, len)])
        ys[2 * i - 1] = s
        ys[2 * i] = c
    end
    naive_kbn_serial(ys)   # combine the per-thread partial results
end
```
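
The helpers `_naive_kbn` and `naive_kbn_serial` are from the linked thread; as a hedged sketch of what they might look like (standard Kahan-Babuška-Neumaier compensated summation, not necessarily the exact code by @tkf):

```
# Compensated (KBN) summation of a chunk, returning the running sum and the
# accumulated compensation separately so chunks can be combined later.
function _naive_kbn(xs)
    s = zero(eltype(xs))
    c = zero(eltype(xs))
    for x in xs
        t = s + x
        c += abs(s) >= abs(x) ? (s - t) + x : (x - t) + s
        s = t
    end
    return s, c
end

# Combine the per-thread (sum, compensation) pairs with one more KBN pass.
function naive_kbn_serial(ys)
    s, c = _naive_kbn(ys)
    return s + c
end
```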
1 Like

Is it `sum(m, models)` or `sum(m -> m(x), models)`? Either way I presume the gradient with respect to `m` is important.

The rule which handles this is here: mapreduce.jl#L74-L99 and the key lines are:

```
fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
y = sum(first, fx_and_pullbacks; dims=dims)  # forward

# f̄_and_x̄s = map(((_, back),) -> back(ȳ), fx_and_pullbacks)  for sum without dims
```

These two `map` steps are the work you could parallelise, by hand or with something like `ThreadsX.map`. My guess is that hoping for the gradient of `ThreadsX.sum` to just work, and be efficient, is asking too much, although I’d like to be wrong.

1 Like

Thank you for the response @mcabbott!

Is it `sum(m, models)` or `sum(m -> m(x), models)`? Either way I presume the gradient with respect to `m` is important.

It is `sum(m -> m(x), models)`; I just edited the original post to correct this.

I’ll give this a deeper look and hope I can parallelize it efficiently!

Naive parallel implementations we discussed here.

I am having trouble opening your link. Could the link field be empty?

I guess we need a FoldsChainRules.jl or something and simply port the rrule for the sequential sum. Many clever things are possible, but that looks like a good start. I think it’s better to do this at Folds.jl rather than ThreadsX.jl, so that we get the Distributed/Dagger-based definitions “for free” as well. Adding AD support to JuliaFolds is rather overdue… but my excuse has always been, “yeah, but there are things like Tullio.jl.” @mcabbott, does Tullio.jl handle this case?

4 Likes

I just implemented a custom adjoint for `ThreadsX.sum` by replacing the `map` and `sum` calls in the regular `rrule` with their `ThreadsX` versions. I am able to achieve a speedup of a factor of 6.5 using 10 threads on an Apple M1 Pro in my use case (and no segfault). I am attaching the implementation. Thank you for the helpful pointer @mcabbott!

```
using ThreadsX
using ChainRules
using ChainRules: RuleConfig, HasReverseMode, rrule_via_ad, ProjectTo, NoTangent, unthunk

function ChainRules.rrule(
        config::RuleConfig{>:HasReverseMode}, ::typeof(ThreadsX.sum), f, xs::AbstractArray)
    # forward pass in parallel, keeping a pullback for every element
    fx_and_pullbacks = ThreadsX.map(x -> rrule_via_ad(config, f, x), xs)
    y = ThreadsX.sum(first, fx_and_pullbacks)

    pullbacks = ThreadsX.map(last, fx_and_pullbacks)

    project = ProjectTo(xs)

    function sum_pullback(ȳ)
        # reverse pass in parallel: apply each recorded pullback to ȳ
        f̄_and_x̄s = ThreadsX.map(back -> back(ȳ), pullbacks)
        # no point thunking, as most of the work is in f̄_and_x̄s which we need either way
        f̄ = if fieldcount(typeof(f)) === 0 # then we don't need to worry about the derivative wrt f
            NoTangent()
        else
            sum(first, f̄_and_x̄s)
        end
        x̄s = ThreadsX.map(unthunk ∘ last, f̄_and_x̄s) # project does not support receiving InplaceableThunks
        return NoTangent(), f̄, project(x̄s)
    end
    return y, sum_pullback
end
```
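
And a brief usage sketch (the toy `models` and `x` below are illustrative; the actual speedup will depend on the model and the thread count):

```
using Zygote, ThreadsX

# With the rrule above in scope, Zygote dispatches to it whenever
# ThreadsX.sum(f, xs) appears in the code being differentiated.
models = [let w = randn(3); x -> sum(w .* x) end for _ in 1:64]
x = randn(3)

loss(ms) = ThreadsX.sum(m -> m(x), ms)
y, back = Zygote.pullback(loss, models)
grads = back(one(y))[1]   # backward pass now runs the ThreadsX-parallel pullback
```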
3 Likes

Just as a general note, if you are using the M1 and are using threading, I would strongly recommend using the nightly for Julia 1.8. The spurious segfaults and hangs with threading got fixed there recently.

2 Likes

Edit: This issue got resolved; it was unrelated to the parallel code and was instead caused by a bug Zygote currently exhibits with mutable structs. The implementation above works.

An update:

While the custom adjoint accelerates the backward pass and executes reliably on test problems, it leads to a non-deterministic `AssertionError` in Zygote’s code with my more complex model. I opened an issue about it here.

Also, the segfaults on ARM are indeed fixed on the nightly for Julia 1.8!

1 Like