I am using Zygote to compute the gradient of a function which calls map(f, x), where applying f to each element of the array x is slow. I would like to speed up the gradient computation by using a threaded map such as ThreadsX.map; however, Zygote cannot differentiate through this. I would therefore like to write a custom adjoint for ThreadsX.map that also parallelizes the backward pass (using threads), but I don't know how to do this and would be grateful for any help.
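A minimal example of the setup (here f is just a cheap stand-in for my actual, expensive per-element function):

using Zygote, ThreadsX

f(x) = exp(-x^2)  # stand-in; the real f is expensive per element

gradient(X -> sum(map(f, X)), rand(3))           # works, but runs f serially
gradient(X -> sum(ThreadsX.map(f, X)), rand(3))  # this is the call Zygote cannot handle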
I found this discussion to be helpful – it is essentially the same problem but for ThreadsX.sum. The solution there implemented a custom adjoint for ThreadsX.sum by taking the regular rrule for sum and replacing the map and sum calls with ThreadsX versions. I tried to do something analogous but couldn't find the rrule for map in ChainRules.jl.
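For reference, my reading of that ThreadsX.sum adjoint is roughly the following sketch (untested, names are my own; it assumes ThreadsX.sum(f, x) and ThreadsX.map(f, x) mirror the Base two-argument forms):

using ChainRulesCore, ThreadsX

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(ThreadsX.sum), f, X::AbstractArray)
    # forward pass: differentiate each f(x) in parallel, keeping the element-wise pullbacks
    pairs = ThreadsX.map(x -> rrule_via_ad(config, f, x), X)
    Y = sum(first, pairs)
    function sum_pullback(dY_raw)
        dY = unthunk(dY_raw)
        # sum sends the same cotangent dY to every element
        backevals = ThreadsX.map(((y, back),) -> back(dY), pairs)
        df = ProjectTo(f)(sum(first, backevals))
        dX = map(last, backevals)
        return (NoTangent(), df, dX)
    end
    return Y, sum_pullback
end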
Alternatively, is there another package that contains a threaded map that is compatible with Zygote?
There is this one, for map(f, ::Tuple), and Zygote has this. The basic idea is quite simple, but there are elaborations to deal with map(f, x, y, z), and to be more efficient in some cases.
using ChainRulesCore, Zygote  # Zygote supplies the AD used by rrule_via_ad

# the stand-in must exist before the rrule below refers to typeof(my_map);
# @show prints f so you can see when my_map is actually being called
my_map(f, Xs...) = map(@show(f), Xs...)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(my_map), f, X::AbstractArray)
    hobbits = my_map(X) do x  # this makes an array of tuples (y, back)
        y, back = rrule_via_ad(config, f, x)
    end
    Y = map(first, hobbits)
    function map_pullback(dY_raw)
        dY = unthunk(dY_raw)
        # Should really do these in the reverse order
        backevals = my_map(hobbits, dY) do (y, back), dy
            df, dx = back(dy)
        end
        df = ProjectTo(f)(sum(first, backevals))
        dX = map(last, backevals)
        return (NoTangent(), df, dX)
    end
    return Y, map_pullback
end
gradient(x -> sum(map(inv, x)), [1,2,3.0])     # reference, with Base.map
gradient(x -> sum(my_map(inv, x)), [1,2,3.0])  # exercises dX, the gradient w.r.t. the array

gradient(x -> sum(map(z -> z/x, 1:3)), 4.0)     # reference, with Base.map
gradient(x -> sum(my_map(z -> z/x, 1:3)), 4.0)  # exercises df, the gradient w.r.t. the closed-over x
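If those match, you can point the stand-in at the threaded map so that both the forward pass and the pullback do their element-wise work on threads. This is just a sketch; it assumes ThreadsX.map accepts multiple arrays the way Base.map does, since the pullback calls it with two:

using ThreadsX

my_map(f, Xs...) = ThreadsX.map(f, Xs...)  # replaces the @show stand-in; the same rrule still applies

gradient(x -> sum(my_map(inv, x)), [1,2,3.0])  # same answer as above, now threaded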