Adjoint for threaded map (ThreadsX.map)

I am using Zygote to compute the gradient of a function which calls map(f, x), where applying f to each element of the array x is slow. I would like to speed up the gradient computation by using a threaded map such as ThreadsX.map; however, Zygote cannot differentiate through this.

So I would like to write a custom adjoint for ThreadsX.map that also parallelizes the backward pass with threads, but I don't know how to do this and would be grateful for any help.
I found this discussion helpful: it is essentially the same problem, but for ThreadsX.sum. The solution there implemented a custom adjoint for ThreadsX.sum by taking the regular rrule for sum and replacing its map and sum calls with the ThreadsX versions. I tried to do something analogous, but couldn't find the rrule for map in ChainRules.jl.

Alternatively, is there another package that contains a threaded map that is compatible with Zygote?

ChainRules.jl does have one, for map(f, ::Tuple), and Zygote has its own rule for map. The basic idea is quite simple; the real rules add elaborations to handle several arguments, as in map(f, x, y, z), and to be more efficient in some cases.
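Stripped of ChainRules and Zygote, the basic idea fits in a few lines: a threaded forward pass that keeps one pullback per element, and a threaded backward pass that applies each pullback to its own cotangent. This is a minimal sketch using only Base.Threads; `scalar_rrule`, `threaded_map` and `threaded_map_rrule` are hypothetical names, and the hand-written rule for `inv` stands in for `rrule_via_ad`. It also skips the tangent for `f` itself, which a full rule gets by summing the per-element `df` (`inv` carries no data, so there is nothing to sum here).

```julia
using Base.Threads: @spawn

# Hypothetical stand-in for `rrule_via_ad(config, f, x)`: a hand-written rule for
# `inv` returning the primal value and a pullback that maps dy to dx.
scalar_rrule(::typeof(inv), x::Real) = (inv(x), dy -> -dy / x^2)

# Hypothetical threaded map: one task per element, which only pays off when each
# call is slow. `map(fetch, ...)` returns results in input order.
threaded_map(f, xs::AbstractVector) = map(fetch, [@spawn(f(x)) for x in xs])

function threaded_map_rrule(f, xs::AbstractVector)
    # Forward pass: compute each y and keep its pullback, in parallel.
    pairs = threaded_map(x -> scalar_rrule(f, x), xs)
    ys = map(first, pairs)
    function map_pullback(dys)
        # Backward pass: feed each cotangent to its own pullback, in parallel.
        return threaded_map(i -> last(pairs[i])(dys[i]), eachindex(pairs, dys))
    end
    return ys, map_pullback
end
```

For example, `ys, back = threaded_map_rrule(inv, [1.0, 2.0, 4.0])` gives `ys == [1.0, 0.5, 0.25]`, and `back(ones(3))` is the gradient of `sum(inv, x)`, i.e. `-1 ./ x .^ 2`.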

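```julia
using Zygote, ChainRulesCore

# Stand-in for a threaded map such as ThreadsX.map. It must be defined before the
# rule below, which refers to typeof(my_map). @show(f) prints each function passed in.
my_map(f, Xs...) = map(@show(f), Xs...)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(my_map), f, X::AbstractArray)
    # Forward pass: differentiate f at each element, keeping its pullback
    hobbits = my_map(X) do x  # this makes an array of (y, back) tuples
        y, back = rrule_via_ad(config, f, x)
    end
    Y = map(first, hobbits)
    function map_pullback(dY_raw)
        dY = unthunk(dY_raw)
        # Should really do these in reverse order (this matters if f is stateful)
        backevals = my_map(hobbits, dY) do (y, back), dy
            back(dy)  # returns (df, dx) for this element
        end
        df = ProjectTo(f)(sum(first, backevals))  # f's tangent sums over elements
        dX = map(last, backevals)
        return (NoTangent(), df, dX)
    end
    return Y, map_pullback
end

gradient(x -> sum(map(inv, x)), [1,2,3.0])
gradient(x -> sum(my_map(inv, x)), [1,2,3.0])  # exercises dX

gradient(x -> sum(map(z -> z/x, 1:3)), 4.0)
gradient(x -> sum(my_map(z -> z/x, 1:3)), 4.0)  # exercises df: the closure captures x
```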
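In the rule, my_map is deliberately a plain map; to actually use threads, its body can be swapped for a threaded map such as ThreadsX.map, as in the question. If you would rather avoid the dependency, here is a minimal stdlib-only sketch, assuming all arrays share the same axes. The name `tmap` is hypothetical. It splits the indices into one chunk per thread so that each task does a batch of calls rather than one.

```julia
using Base.Threads: @spawn, nthreads

# Hypothetical helper: chunked, task-based threaded map over arrays that share
# the same axes. Each task evaluates f over one contiguous chunk of indices.
function tmap(f, Xs::AbstractArray...)
    X1 = first(Xs)
    all(X -> axes(X) == axes(X1), Xs) ||
        throw(DimensionMismatch("tmap: all arrays must have the same axes"))
    idx = collect(eachindex(Xs...))       # indices valid for every array
    isempty(idx) && return map(f, Xs...)  # nothing to parallelize
    chunksize = cld(length(idx), nthreads())
    tasks = map(Iterators.partition(idx, chunksize)) do chunk
        @spawn [f(map(X -> X[i], Xs)...) for i in chunk]
    end
    # Concatenate the chunks in order and restore the shape of the first array.
    return reshape(reduce(vcat, map(fetch, tasks)), size(X1))
end
```

Since `tmap` accepts any number of arrays, defining `my_map(f, Xs...) = tmap(f, Xs...)` covers both calls in the rule, `my_map(f, X)` in the forward pass and `my_map(g, hobbits, dY)` in the pullback. Whether it pays off depends on how slow `f` is compared with task overhead; Julia has to be started with several threads (e.g. `julia --threads=4`).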