I did a quick & dirty solution.
I took `tmap` from ThreadTools and adapted the `rrule` for `map(f, xs::Tuple...)` from ChainRules as follows:
using ChainRulesCore, ThreadTools
import ChainRulesCore: rrule  # so the method below extends ChainRulesCore.rrule

function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(ThreadTools.tmap), f::F, xs::Tuple...) where {F}
length_y = minimum(length, xs)
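# Forward pass: call `rrule_via_ad` once per element (up to the shortest input),
# keeping each primal value together with its pullback.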
hobbits = ntuple(length_y) do i
args = getindex.(xs, i)
rrule_via_ad(config, f, args...)
end
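# The primal output: take the value out of each (value, pullback) pair.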
y = ThreadTools.tmap(first, hobbits)
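# Wrap the number of inputs in a Val so the ntuple calls below stay type-stable.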
num_xs = Val(length(xs))
paddings = map(x -> ntuple(Returns(NoTangent()), (length(x) - length_y)), xs)
all(isempty, paddings) || @error """map(f, xs::Tuple...) does not allow mismatched lengths!
But its `rrule` does; when JuliaLang/julia#42216 is fixed this warning should be removed."""
function map_pullback(dy_raw)
dy = unthunk(dy_raw)
# We want to call the pullbacks in `rrule_via_ad` in reverse sequence to the forward pass:
backevals = ntuple(length_y) do i
rev_i = length_y - i + 1
last(hobbits[rev_i])(dy[rev_i])
end |> reverse
# This df doesn't infer, could test Base.issingletontype(F), but it's not the only inference problem.
df = ProjectTo(f)(sum(first, backevals))
# Now unzip that. Because `map`, like `zip`, should stop when any `x` runs out, some `dx`s may need padding.
# Although in fact, `map(+, (1,2), (3,4,5))` is an error... https://github.com/JuliaLang/julia/issues/42216
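# Each pullback result has the layout (df_i, dx_i...), so entry k+1 is the tangent for input k.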
dxs = ntuple(num_xs) do k
dx_short = ThreadTools.tmap(bv -> bv[k+1], backevals)
ProjectTo(xs[k])((dx_short..., paddings[k]...)) # ProjectTo makes the Tangent for us
end
return (NoTangent(), df, dxs...)
end
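# Shortcut: if the incoming tangent is a hard zero, every gradient is NoTangent().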
map_pullback(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
return y, map_pullback
end
The rule is almost unchanged from the original, except that the `map`s were replaced by `tmap`s. Sometimes the power and simplicity of Julia is just stunning.
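For anyone who wants to try it, here is a rough sanity check. It assumes Zygote as the reverse-mode AD (its rule config has `HasReverseMode`, so `rrule_via_ad` works); the squaring function and the input tuple are just made up for illustration:

```julia
using Zygote, ThreadTools

f(x) = x^2                  # hypothetical element function
xs = (1.0, 2.0, 3.0)        # hypothetical input tuple

# Differentiate a scalar function of tmap's output; the rule above should be picked up,
# and the result should come back as ((2.0, 4.0, 6.0),).
g = Zygote.gradient(t -> sum(ThreadTools.tmap(f, t)), xs)
```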