Help needed with Zygote, Flux, and custom adjoints on GPU

I am incorporating some simple non-standard elements into a larger DNN.

To do this, I had to define some reverse rules manually using ChainRulesCore.

This works great on the CPU, but somehow does not carry over to the GPU.

I narrowed it down to an MWE in which GPU compilation of the pullback function fails:

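```julia
using Zygote
using CUDA
using Flux  # provides `gpu`

# Placeholder body (hypothetical): the element-wise function from the
# original post was not included.
f(A) = map(A) do a
    a^2
end

A = rand(Float32, 128, 128)
gA = gpu(A)
f(A)   # works
f(gA)  # works

y, pb = pullback(f, A)   # works

y, pb = pullback(f, gA)  # fails
```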

Any insights?

I don't believe `map` works in reverse mode on the GPU with Zygote. Workarounds are to define an `rrule` that encapsulates the whole `map` call, so Zygote never differentiates through `map` itself, or to use broadcasting instead.
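A minimal sketch of both workarounds, assuming Zygote.jl and ChainRulesCore.jl are installed. The element-wise function `elementwise(a) = a^2` and the names `f_bcast` and `f_map` are hypothetical stand-ins, since the body of the original `map` was not shown, and the hand-written pullback is only correct for that particular function. The sketch runs on the CPU; GPU behaviour is not verified here.

```julia
using Zygote, ChainRulesCore

# Hypothetical element-wise function standing in for the body of the
# `map(A) do a ... end` block from the question.
elementwise(a) = a^2

# Workaround 1: broadcast instead of map; Zygote differentiates
# broadcasting directly.
f_bcast(A) = elementwise.(A)

# Workaround 2: wrap the map in its own function and give that function
# a reverse rule, so Zygote uses the rule instead of differentiating `map`.
f_map(A) = map(elementwise, A)

function ChainRulesCore.rrule(::typeof(f_map), A)
    y = f_map(A)
    function f_map_pullback(ȳ)
        # d/da (a^2) = 2a, applied element-wise and scaled by the
        # incoming cotangent. This is a plain forward `map` over two arrays.
        Ā = map((a, dy) -> 2a * dy, A, unthunk(ȳ))
        return NoTangent(), Ā
    end
    return y, f_map_pullback
end

A = rand(Float32, 4, 4)
ȳ = ones(Float32, size(A))

y1, pb1 = Zygote.pullback(f_bcast, A)
y2, pb2 = Zygote.pullback(f_map, A)
g1 = pb1(ȳ)[1]
g2 = pb2(ȳ)[1]
# With CUDA.jl, the same calls can be tried on `cu(A)` (with a `cu` cotangent).
```

The `rrule` route is worth it when the element-wise body is too complex to broadcast cleanly, at the cost of writing and maintaining the derivative by hand.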