How to compute `gradient.(f, w)` on the GPU?

See the earlier thread "This custom Zygote.jl adjoint is not giving me the speed up I expected and how to migrate to GPU?" for context.

I can’t quite get this to work on the GPU unless I differentiate the function myself.
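To make the workaround concrete, here is a minimal sketch of what "differentiating the function myself" could look like. The scalar function `f` and its derivative `df` are hypothetical stand-ins, not the function from the linked thread. The sketch runs on a plain CPU array; the assumption is that with CUDA.jl loaded, the same broadcast would run over a `CuArray` created with `cu(w)`.

```julia
# Hypothetical scalar function (an assumption for illustration only).
f(w) = w^2 + sin(w)

# Its derivative, worked out by hand instead of calling Zygote's `gradient`.
df(w) = 2w + cos(w)

# Plain CPU array here. Assumption: with CUDA.jl one would write
# `w = cu(Float32[0, 1, 2])` and keep the broadcast below unchanged.
w = Float32[0, 1, 2]

# Broadcasting the hand-written derivative replaces `gradient.(f, w)`.
g = df.(w)
```

Note that `gradient.(f, w)` returns an array of one-element tuples, whereas `df.(w)` returns an array of plain numbers, so downstream code would not need the `first.(...)` unwrapping step.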