I am using Zygote to compute the gradient of a function which calls `map(f, x)`, where applying `f` to each element of the array `x` is slow. I would like to speed up the gradient computation by using a threaded map such as `ThreadsX.map`; however, Zygote cannot differentiate through this.

Thus, I would like to write a custom adjoint for `ThreadsX.map` that also parallelizes the backward pass (using threads), but I don't know how to do this and would be grateful for any help.

I found this discussion to be helpful: it is essentially the same problem but for `ThreadsX.sum`. The solution implemented a custom adjoint for `ThreadsX.sum` by taking the regular `rrule` for `sum` and replacing the `map` and `sum` calls with ThreadsX versions. I tried to do something analogous to this but couldn't find the `rrule` for `map` in `ChainRules.jl`.

Alternatively, is there another package that contains a threaded map that is compatible with Zygote?

There is this one, for `map(f, ::Tuple)`, and Zygote has this. The basic idea is quite simple, but there are elaborations to deal with `map(f, x, y, z)`, and to be more efficient in some cases.

```
using ChainRulesCore  # rrule, RuleConfig, rrule_via_ad, unthunk, ProjectTo, NoTangent
using Zygote          # gradient

# Stand-in for a threaded map. It must exist before the rule mentions typeof(my_map).
# @show prints f each time my_map is called.
my_map(f, Xs...) = map(@show(f), Xs...)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(my_map), f, X::AbstractArray)
    hobbits = my_map(X) do x  # this makes an array of tuples
        y, back = rrule_via_ad(config, f, x)
    end
    Y = map(first, hobbits)
    function map_pullback(dY_raw)
        dY = unthunk(dY_raw)
        # Should really do these in the reverse order
        backevals = my_map(hobbits, dY) do (y, back), dy
            # Per-element gradients w.r.t. f and x; the do block returns this tuple.
            # Distinct names avoid capturing the outer df assigned below.
            df_i, dx_i = back(dy)
        end
        df = ProjectTo(f)(sum(first, backevals))
        dX = map(last, backevals)
        return (NoTangent(), df, dX)
    end
    return Y, map_pullback
end

gradient(x -> sum(map(inv, x)), [1,2,3.0])
gradient(x -> sum(my_map(inv, x)), [1,2,3.0])  # dX
gradient(x -> sum(map(z -> z/x, 1:3)), 4.0)
gradient(x -> sum(my_map(z -> z/x, 1:3)), 4.0)  # df
```
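
To get threads into both passes, `my_map` could be swapped for a threaded map. Here is a minimal sketch using only Base. `threaded_map` is a hypothetical name (not something from ThreadsX or ChainRules), and it assumes `f` is safe to call concurrently. I have not checked that Zygote's pullbacks are safe to run concurrently in every case, so treat this as a starting point rather than a verified solution:

```julia
# Hypothetical helper: spawn one task per element, then fetch the results in order.
# Assumes f is safe to call from several threads at once.
function threaded_map(f, X::AbstractArray, Xs::AbstractArray...)
    tasks = map((args...) -> Threads.@spawn(f(args...)), X, Xs...)
    return map(fetch, tasks)
end

squares = threaded_map(x -> x^2, [1, 2, 3])
sums = threaded_map(+, [1, 2], [10, 20])
```

Defining `my_map(f, Xs...) = threaded_map(f, Xs...)` instead of the `map`-based version should then run both the `rrule_via_ad` calls and the `back(dy)` calls on separate tasks. One task per element only pays off when each call to `f` is slow, as in the question, and Julia needs to be started with more than one thread (e.g. `julia -t 4`).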
