I am using Zygote to compute the gradient of a function that calls map(f, x), where applying f to each element of the array x is slow. I would like to speed up the gradient computation by using a threaded map such as ThreadsX.map; however, Zygote cannot differentiate through ThreadsX.map.
I would therefore like to write a custom adjoint for ThreadsX.map that also parallelizes the backward pass across threads, but I don't know how to do this and would be grateful for any help.
I found this discussion helpful; it addresses essentially the same problem, but for ThreadsX.sum. The solution there implemented a custom adjoint for ThreadsX.sum by taking the regular rrule for sum and replacing its map and sum calls with their ThreadsX versions. I tried to do something analogous for map but couldn't find the rrule for map in ChainRules.jl.
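For reference, here is a minimal, untested-in-production sketch of what such a rule might look like. Rather than copying an existing ChainRules rule, it records a per-element pullback with ChainRulesCore's `rrule_via_ad` (which calls back into Zygote for `f`) and then runs both the forward and backward passes through ThreadsX.map. Assumptions, all mine: a single `AbstractVector` input, no keyword arguments such as `basesize`, `f` captures no differentiable state (its gradient is reported as `NoTangent()`), and calling Zygote pullbacks concurrently from several threads is safe for your `f`.

```julia
using ChainRulesCore
import ThreadsX

# Sketch of a reverse rule for ThreadsX.map(f, xs) that parallelizes both passes.
# Assumptions: one vector input, no kwargs, and `f` has no differentiable
# captured state, so its own tangent is dropped (returned as NoTangent()).
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(ThreadsX.map), f, xs::AbstractVector
)
    # Forward pass: evaluate f and record its pullback for every element, in parallel.
    ys_and_backs = ThreadsX.map(x -> rrule_via_ad(config, f, x), xs)
    ys = map(first, ys_and_backs)
    backs = map(last, ys_and_backs)

    function threadsx_map_pullback(ȳ)
        ȳ = unthunk(ȳ)
        # If the output was not used, no gradient flows back to xs.
        ȳ isa AbstractZero && return (NoTangent(), NoTangent(), ZeroTangent())
        # Backward pass: call each element's pullback in parallel and keep the
        # cotangent for x (entry 2; entry 1 is the tangent for f itself).
        x̄s = ThreadsX.map(i -> backs[i](ȳ[i])[2], eachindex(backs, ȳ))
        return (NoTangent(), NoTangent(), x̄s)
    end
    return ys, threadsx_map_pullback
end
```

With this defined, `Zygote.gradient(x -> sum(ThreadsX.map(sin, x)), xs)` should agree with the gradient of the serial `map` version. If `f` is a closure over parameters you are differentiating with respect to, the per-element tangents for `f` would also need to be accumulated instead of discarded.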
Alternatively, is there another package that contains a threaded map that is compatible with Zygote?