Parallel Reductions with Zygote

I am trying to speed up the gradient computation of a model of the form

sum(m->m(x), models),

where models is an array of constituent models. This is (supposed to be) easily parallelizable, but I am having trouble making it work with Zygote.

First, I tried the parallel sum implementation in ThreadsX.jl. On an Intel processor (a 24-core Intel(R) Xeon(R) CPU X5690 @ 3.47GHz) the forward pass is accelerated, but the backward pass is slower than the serial version. Further, I get a segfault when running natively on an Apple M1 Pro, which could be related to threading issues that are still present on ARM.

Second, writing my own parallel version of sum(f, x) that is compatible with Zygote also proved more difficult than expected, because Zygote does not support array mutation, which precludes the solutions discussed here.

This leaves writing a custom adjoint for the custom implementation as a still-viable option. I don't yet know how to do this effectively and would be grateful for any help.

Also, here is an implementation of a parallel summation that works well for the forward pass but that Zygote can't differentiate through (apparently because of a try-catch block in the @threads implementation and, secondarily, because of the array mutation in the parallel section):

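```julia
using Base.Threads

function parallel_sum(f, x::AbstractVector)
    n, k = length(x), nthreads()
    m = n ÷ k                    # elements per thread (the remainder is handled below)
    fx1 = f(x[1])                # to get the correct output type
    sum_x = fill(fx1, k)         # one partial sum per thread (mutated below)
    @threads for i in 1:k
        y = @view x[(i-1)*m+1 : (i*m)]
        sum_x[i] = sum(f, y; init = zero(fx1))  # init guards against empty chunks when n < k
    end
    result = zero(fx1)
    for i in 1:k
        result += sum_x[i]
    end
    # leftover elements when n is not divisible by k
    if rem(n, k) != 0
        result += sum(f, @view(x[m*k+1:end]))
    end
    return result
end
```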
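For comparison, here is a minimal sketch (an assumption, not code from this thread) of the same chunking idea without the shared `sum_x` array: each task spawned with `Threads.@spawn` returns its own partial sum, and the results are combined with `fetch`. The name `spawn_sum` and the `ntasks` keyword are hypothetical. This removes the array mutation, but I would not expect it to make Zygote work on its own, since differentiating through task creation would still need a custom rule; treat it as a forward-pass sketch.

```julia
using Base.Threads

# Chunked parallel sum that never writes into a shared array: each task
# returns its own partial sum, and the partial sums are added afterwards.
function spawn_sum(f, x::AbstractVector; ntasks::Int = nthreads())
    isempty(x) && throw(ArgumentError("spawn_sum needs a non-empty vector"))
    chunk = cld(length(x), clamp(ntasks, 1, length(x)))  # elements per task, rounded up
    tasks = map(Iterators.partition(eachindex(x), chunk)) do idx
        @spawn sum(f, @view x[idx])   # each task reduces one contiguous chunk
    end
    return sum(fetch, tasks)          # combine the per-task partial sums
end
```

For example, `spawn_sum(abs2, collect(1.0:10.0))` returns `385.0`. Because every chunk from `Iterators.partition` is non-empty, no `init` is needed inside the tasks.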

We recently had some success using Folds here. Don’t know if this helps with Zygote though.

Thank you for the pointer @goerch!

I just tried out Folds, but realized it relies on Transducers, the same backend as ThreadsX, and for this reason it has the same performance (Intel) and segfault (ARM) issues as ThreadsX when combined with Zygote.


Naive parallel implementations we discussed here.

Edit: sorry, corrected link.
Edit: just for reference, here is my latest parallelized version:

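```julia
function naive_kbn_parallel(xs)  # credit to @tkf from another thread
    len = length(xs)
    nt = min(Threads.nthreads(), len)
    ys = Vector{Float64}(undef, 2 * nt)   # one (sum, compensation) pair per chunk
    chunk = (len + nt - 1) ÷ nt           # ceiling division: elements per chunk
    Threads.@threads for i in 1:nt
        s, c = _naive_kbn(@view xs[(i - 1) * chunk + 1:min(i * chunk, len)])
        ys[2 * i - 1] = s
        ys[2 * i] = c
    end
    # … (the step that combines the pairs in ys was cut off in the post)
end
```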
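The post does not include `_naive_kbn`, and the end of `naive_kbn_parallel` is cut off. As an assumption, here is a sketch of a Neumaier (Kahan–Babuška) summation helper with the `(s, c)` return interface the threaded loop expects, plus a hypothetical `combine_kbn` that folds the per-chunk pairs stored in `ys` through the same compensated loop. The original code may well differ; the accumulators start at `0.0` because `ys` is a `Vector{Float64}`.

```julia
# Hypothetical sequential helper with the interface naive_kbn_parallel expects:
# Neumaier (Kahan–Babuška) summation returning the running sum s and the
# accumulated compensation c separately (s + c is the corrected result).
function _naive_kbn(xs)
    s = 0.0
    c = 0.0
    for x in xs
        t = s + x
        if abs(s) >= abs(x)
            c += (s - t) + x   # recover the low-order part of x lost in s + x
        else
            c += (x - t) + s   # recover the low-order part of s lost in s + x
        end
        s = t
    end
    return s, c
end

# Hypothetical combining step: run the per-chunk (s, c) pairs in ys through
# the same compensated loop and return the corrected total.
function combine_kbn(ys)
    s, c = _naive_kbn(ys)
    return s + c
end
```

For example, `_naive_kbn([1.0, 1e100, 1.0, -1e100])` returns `(0.0, 2.0)`, so the compensated total is `2.0`, whereas a plain left-to-right loop over the same values loses both `1.0` terms.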

Is it sum(m, models) or sum(m -> m(x), models)? Either way I presume the gradient with respect to m is important.

The rule that handles this is here: mapreduce.jl#L74-L99, and the key lines are:

```julia
fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
y = sum(first, fx_and_pullbacks; dims=dims)  # forward

f̄_and_x̄s = call.(pullbacks, broadcast_ȳ)   # gradient
# f̄_and_x̄s = map(((_, back),) -> back(ȳ), fx_and_pullbacks)  for sum without dims
```

These two map steps are the work you could parallelise, by hand or with something like …. My guess is that hoping for the gradient of ThreadsX.sum to just work, and be efficient, is asking too much, although I'd like to be wrong.
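To make the structure concrete, here is a Base-only sketch (no ChainRules or Zygote) of parallelising exactly those two map steps. The names `square_rrule`, `tmap`, and `threaded_sum_rrule` are hypothetical: `square_rrule` stands in for `rrule_via_ad` with a hand-written rule for `f(x) = x^2`, and the gradient with respect to `f` itself is ignored because this `f` has no fields.

```julia
# Hypothetical hand-written rule for f(x) = x^2: returns the value and its pullback.
square_rrule(x) = (x^2, ȳ -> 2x * ȳ)

# Threaded map that keeps results in input order (one task per element, for clarity).
tmap(g, xs) = fetch.([Threads.@spawn g(x) for x in xs])

# Reverse rule for sum(f, xs) with both map steps run on threads.
function threaded_sum_rrule(rule, xs::AbstractVector)
    fx_and_pullbacks = tmap(rule, xs)          # forward: parallel map over elements
    y = sum(first, fx_and_pullbacks)           # total of the values f(x)
    pullbacks = map(last, fx_and_pullbacks)
    # backward: every element gets the same ȳ because the output is a scalar sum
    sum_pullback(ȳ) = tmap(back -> back(ȳ), pullbacks)
    return y, sum_pullback
end
```

With `y, back = threaded_sum_rrule(square_rrule, [1.0, 2.0, 3.0])`, `y` is `14.0` and `back(1.0)` is `[2.0, 4.0, 6.0]`. Spawning one task per element is only for readability; for real workloads you would chunk the inputs or use a library threaded map.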


Thank you for the response @mcabbott!

> Is it sum(m, models) or sum(m -> m(x), models)? Either way I presume the gradient with respect to m is important.

It is sum(m -> m(x), models), I just edited the original post to correct this.

I’ll give this a deeper look and hope I can parallelize it efficiently!

> Naive parallel implementations we discussed here.

I am having trouble opening your link. Could the link field be empty?

I guess we need FoldsChainRules.jl or something, and simply port the rrule for the sequential sum. Many clever things are possible, but that looks like a good start. I think it's better to do this in Folds.jl than in ThreadsX.jl so that we get Distributed/Dagger-based definitions "for free" as well. Adding AD support to JuliaFolds is rather overdue… But my excuse has always been, "yeah but there are things like Tullio.jl." @mcabbott Does Tullio.jl handle this case?


I just implemented a custom adjoint for ThreadsX.sum by replacing the map and sum calls in the regular rrule with their ThreadsX versions. In my use case this gives a 6.5x speedup with 10 threads on an Apple M1 Pro (and no segfault). The implementation is below. Thank you for the helpful pointer @mcabbott!

```julia
using ThreadsX
using ChainRules
using ChainRules: RuleConfig, HasReverseMode, rrule_via_ad, ProjectTo, NoTangent, unthunk

# Reverse rule for ThreadsX.sum(f, xs): the serial rule from ChainRules'
# mapreduce.jl with map/sum replaced by their ThreadsX counterparts.
function ChainRules.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(ThreadsX.sum), f, xs::AbstractArray)
    # forward pass: evaluate f and record its pullback for every element, in parallel
    fx_and_pullbacks = ThreadsX.map(x -> rrule_via_ad(config, f, x), xs)
    y = ThreadsX.sum(first, fx_and_pullbacks)

    pullbacks = ThreadsX.map(last, fx_and_pullbacks)

    project = ProjectTo(xs)

    function sum_pullback(ȳ)
        # no dims keyword here, so every pullback receives the same ȳ
        f̄_and_x̄s = ThreadsX.map(back -> back(ȳ), pullbacks)
        # no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
        f̄ = if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
            NoTangent()
        else
            ThreadsX.sum(first, f̄_and_x̄s)
        end
        x̄s = ThreadsX.map(unthunk ∘ last, f̄_and_x̄s) # project does not support receiving InplaceableThunks
        return NoTangent(), f̄, project(x̄s)
    end
    return y, sum_pullback
end
```

Just as a general note, if you are using the M1 with threading, I would strongly recommend using the nightly for Julia 1.8. The spurious segfaults and hangs with threading were fixed there recently.


Edit: This issue was resolved. It was unrelated to the parallel code and was instead caused by a bug that Zygote currently exhibits with mutable structs. The implementation above works.

An update:

While the custom adjoint accelerates the backward pass and executes reliably on test problems, it leads to a non-deterministic AssertionError in Zygote’s code with my more complex model. I opened an issue about it here.

Also, the segfaults on ARM are indeed fixed on the nightly for Julia 1.8!
