Hi All,
does anyone have a ChainRules rule for a multi-threaded map? It should be “relatively” easy to implement by adapting ChainRules.jl/src/rulesets/Base/base.jl at 955941ec31aab80942658f517a2e024b3e7c812b · JuliaDiff/ChainRules.jl · GitHub to the multi-threaded map of choice, but I admit that I would rather use someone else’s solution.
Thanks in advance.
Tomas
I did a quick-and-dirty solution. I took tmap from ThreadTools and adapted the rrule from ChainRules as follows:
using ChainRulesCore, ThreadTools

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(ThreadTools.tmap), f::F, xs::Tuple...) where {F}
    length_y = minimum(length, xs)
    hobbits = ntuple(length_y) do i
        args = getindex.(xs, i)
        rrule_via_ad(config, f, args...)
    end
    y = ThreadTools.tmap(first, hobbits)
    num_xs = Val(length(xs))
    paddings = map(x -> ntuple(Returns(NoTangent()), (length(x) - length_y)), xs)
    all(isempty, paddings) || @error """map(f, xs::Tuple...) does not allow mismatched lengths!
        But its `rrule` does; when JuliaLang/julia #42216 is fixed this warning should be removed."""
    function map_pullback(dy_raw)
        dy = unthunk(dy_raw)
        # We want to call the pullbacks in `rrule_via_ad` in reverse sequence to the forward pass:
        backevals = ntuple(length_y) do i
            rev_i = length_y - i + 1
            last(hobbits[rev_i])(dy[rev_i])
        end |> reverse
        # This df doesn't infer, could test Base.issingletontype(F), but it's not the only inference problem.
        df = ProjectTo(f)(sum(first, backevals))
        # Now unzip that. Because `map` like `zip` should stop when any `x` stops, some `dx`s may need padding.
        # Although in fact, `map(+, (1,2), (3,4,5))` is an error... https://github.com/JuliaLang/julia/issues/42216
        dxs = ntuple(num_xs) do k
            dx_short = ThreadTools.tmap(bv -> bv[k+1], backevals)
            ProjectTo(xs[k])((dx_short..., paddings[k]...))  # ProjectTo makes the Tangent for us
        end
        return (NoTangent(), df, dxs...)
    end
    # Zero cotangent: add a method to the pullback itself (the original `map_back` name was never called).
    map_pullback(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
    return y, map_pullback
end
The rule is almost unchanged from the original, except that the maps were replaced by tmaps. Sometimes the power and simplicity of Julia is just stunning.
If f is expensive, then rrule_via_ad(config, f, args...) is the bit which needs multi-threading in the forward pass (and last(hobbits[rev_i])(dy[rev_i]) in the reverse pass), while map(first, hobbits) should be almost free.
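A minimal sketch of that split, using only Base.Threads rather than ThreadTools or ChainRules. Here `toy_rrule` and `threaded_map` are hypothetical stand-ins (for `rrule_via_ad(config, f, x)` and for a threaded map such as `tmap`), so treat this as an illustration of where the threading belongs, not as the actual rule:

```julia
using Base.Threads: @spawn

# Hypothetical stand-in for `rrule_via_ad(config, f, x)`:
# returns the primal value together with a pullback closure.
function toy_rrule(x::Float64)
    y = sin(x)^2                                # imagine this is expensive
    pullback(dy) = dy * 2 * sin(x) * cos(x)     # derivative of sin(x)^2 times dy
    return y, pullback
end

# Hypothetical minimal threaded map: one task per element, results kept in order.
threaded_map(f, xs) = map(fetch, [@spawn f(x) for x in xs])

xs = [0.1, 0.2, 0.3]

# Forward pass: the expensive calls run on separate tasks ...
hobbits = threaded_map(toy_rrule, xs)
# ... while pulling out the primal values is cheap and can stay serial.
ys = map(first, hobbits)

# Reverse pass: calling the stored pullbacks is the expensive part again.
dys = ones(length(xs))
dxs = threaded_map(i -> last(hobbits[i])(dys[i]), eachindex(xs))
```

Start Julia with more than one thread (for example `julia -t 4`) to see any benefit; with a single thread the tasks simply run one after another.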
Thanks for your comment. It forced me to dig a bit deeper and understand the rule better. My new version is a less general version of the original and looks as follows:
using ChainRulesCore, ThreadTools

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(ThreadTools.tmap), f::F, xs) where {F}
    hobbits = tmap(xs) do x
        rrule_via_ad(config, f, x)
    end
    y = map(first, hobbits)
    length_y = length(y)
    num_xs = Val(length(xs))
    function map_pullback(dy_raw)
        dy = unthunk(dy_raw)
        # We want to call the pullbacks in `rrule_via_ad` in reverse sequence to the forward pass:
        backevals = tmap(1:length_y) do i
            rev_i = length_y - i + 1
            last(hobbits[rev_i])(dy[rev_i])
        end |> reverse
        df = ProjectTo(f)(sum(first, backevals))
        dx_short = map(bv -> bv[2], backevals)
        dx = ProjectTo(xs)(dx_short)  # ProjectTo makes the Tangent for us
        return (NoTangent(), df, dx)
    end
    return y, map_pullback
end
I had not really thought about how rrule_via_ad works.