Help needed with Zygote, Flux, and custom adjoints on GPU

I am incorporating some simple non-standard elements in a larger DNN.

To achieve this, I had to manually define some reverse rules using the ChainRulesCore syntax:

ChainRulesCore.rrule(::typeof(f),args...)

This works great on the CPU, but somehow it is not portable to the GPU.

I narrowed it down to an MWE where GPU compilation of the pullback function fails:


using Flux   # provides gpu
using Zygote
using CUDA

f(A) = map(A) do a
    a
end

A = rand(Float32,128,128)
gA = gpu(A)
f(A) # works
f(gA) # works

y, pb = pullback(f, A)   # works

y, pb = pullback(f, gA)  # fails

Any insights?

I don’t believe map works in reverse mode on the GPU with Zygote. Workarounds are to define an rrule that encapsulates the map, or to use broadcasting instead.
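
Here is a rough sketch of both workarounds, assuming the body of your map is a plain element-wise function. The names g, dg, f_bcast and f_map are placeholders I made up for illustration, and dg is a hand-written derivative you would have to supply yourself. I only checked the logic against CPU arrays, so treat the GPU behaviour as something to verify on your side.

```julia
using Zygote
using ChainRulesCore

# Hypothetical element-wise function standing in for the body of the `do` block.
g(a) = a^2 + 1
# Hand-written derivative of g (an assumption: your real function needs its own).
dg(a) = 2a

# Workaround 1: broadcast instead of map.
f_bcast(A) = g.(A)

# Workaround 2: keep the map, but hide it behind an rrule so that
# Zygote never has to differentiate through map itself.
f_map(A) = map(g, A)

function ChainRulesCore.rrule(::typeof(f_map), A)
    y = map(g, A)
    # The pullback is element-wise too, so it is written as a broadcast.
    f_map_pullback(ȳ) = (NoTangent(), unthunk(ȳ) .* dg.(A))
    return y, f_map_pullback
end

A = rand(Float32, 4, 4)        # swap in a GPU array (e.g. gpu(A)) to test on the GPU
Δ = ones(Float32, size(A))     # cotangent seed; needs to live on the same device as A

y1, pb1 = pullback(f_bcast, A)
y2, pb2 = pullback(f_map, A)

grad_bcast = pb1(Δ)[1]
grad_map = pb2(Δ)[1]
```

On the CPU both gradients should agree with each other. If the broadcast version works for your real function, it is probably the simpler route, since it avoids maintaining a derivative by hand.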