A chainrule for multi-threaded map

Hi All,

does anyone has a a chainrule for multi-threaded map? It should be “relatively” easy to implement by adapting ChainRules.jl/src/rulesets/Base/base.jl at 955941ec31aab80942658f517a2e024b3e7c812b · JuliaDiff/ChainRules.jl · GitHub for multi-threaded map of choice, but I admit that I would rather take someone else’s solution.

Thanks in advance.
Tomas

I did a quick& dirt solution.
I took tmap from ThreadTools and adapted the rrule from ChainRules as follows

function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(ThreadTools.tmap), f::F, xs::Tuple...) where {F}
    length_y = minimum(length, xs)
    hobbits = ntuple(length_y) do i
        args = getindex.(xs, i)
        rrule_via_ad(config, f, args...)
    end
    y = ThreadTools.tmap(first, hobbits)
    num_xs = Val(length(xs))
    paddings = map(x -> ntuple(Returns(NoTangent()), (length(x) - length_y)), xs)
    all(isempty, paddings) || @error """map(f, xs::Tuple...) does not allow mistmatched lengths!
        But its `rrule` does; when JuliaLang/julia #42216 is fixed this warning should be removed."""
    function map_pullback(dy_raw)
        dy = unthunk(dy_raw)
        # We want to call the pullbacks in `rrule_via_ad` in reverse sequence to the forward pass:
        backevals = ntuple(length_y) do i
            rev_i = length_y - i + 1
            last(hobbits[rev_i])(dy[rev_i])
        end |> reverse
        # This df doesn't infer, could test Base.issingletontype(F), but it's not the only inference problem.
        df = ProjectTo(f)(sum(first, backevals))
        # Now unzip that. Because `map` like `zip` should when any `x` stops, some `dx`s may need padding.
        # Although in fact, `map(+, (1,2), (3,4,5))` is an error... https://github.com/JuliaLang/julia/issues/42216
        dxs = ntuple(num_xs) do k
            dx_short = ThreadTools.tmap(bv -> bv[k+1], backevals)
            ProjectTo(xs[k])((dx_short..., paddings[k]...))  # ProjectTo makes the Tangent for us
        end
        return (NoTangent(), df, dxs...)
    end
    map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
    return y, map_pullback
end

The rule is almost unchaged from the original, except maps were replaced by tmaps. Sometimes, the power and simplicity of is just stunning.

If f is expensive, then this rrule_via_ad(config, f, args...) is the bit which needs multi-threading (in forward pass, and last(hobbits[rev_i])(dy[rev_i]) in reverse) while map(first, hobbits) should be almost free.

Thanks for your comment. It has forced me to dig a bit deeper and understand the rule better. My new version is less general version of the original, looking as

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(ThreadTools.tmap), f::F, xs) where {F}
    hobbits = tmap(xs) do x
        rrule_via_ad(config, f, x)
    end
    y = map(first, hobbits)
    length_y = length(y)
    num_xs = Val(length(xs))
    function map_pullback(dy_raw)
        dy = unthunk(dy_raw)
        # We want to call the pullbacks in `rrule_via_ad` in reverse sequence to the forward pass:
        backevals = tmap(1:length_y) do i
            rev_i = length_y - i + 1
            last(hobbits[rev_i])(dy[rev_i])
        end |> reverse

        df = ProjectTo(f)(sum(first, backevals))

        dx_short = map(bv -> bv[2], backevals)
        dx = ProjectTo(xs)(dx_short)  # ProjectTo makes the Tangent for us
        return (NoTangent(), df, dx)
    end
    return y, map_pullback
end

I have not thought how aware of rrule_via_ad works.