I believe the only part of the Distributed API that Zygote supports is `pmap`, and even then only in a limited capacity. If you can't make do with that, then consider either writing custom rules (a ChainRulesCore `rrule`) for the distributed code, or finding a way to move the distributed part outside of the `gradient`/`withgradient` call.
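A minimal sketch of the second option, assuming a toy loss and random data (the names `loss`, `batch_grad` and `distributed_gradient` are hypothetical, not part of Zygote or Distributed). Each worker calls `Zygote.gradient` on its own batch, and `pmap` sits outside any differentiated code, so Zygote never has to trace through Distributed. Because the gradient of a sum is the sum of the gradients, adding up the per-batch results gives the gradient of the total loss.

```julia
using Distributed
addprocs(2)  # assumption: Zygote is installed in an environment the workers load

@everywhere using Zygote

# Hypothetical per-batch loss, defined on every process
@everywhere loss(w, x) = sum(abs2, w .* x)

# Gradient of one batch's loss w.r.t. w, computed entirely on one worker
@everywhere batch_grad(w, x) = Zygote.gradient(v -> loss(v, x), w)[1]

# pmap runs outside any gradient call: workers differentiate their own
# batches, then the per-batch gradients are summed on the main process
function distributed_gradient(w, batches)
    grads = pmap(x -> batch_grad(w, x), batches)
    return reduce(+, grads)
end

w = [1.0, 2.0, 3.0]
batches = [rand(3) for _ in 1:4]
g = distributed_gradient(w, batches)
```

If the distributed call really has to stay inside the function being differentiated, the other route is to give that function its own `rrule`, so Zygote uses your pullback instead of differentiating through Distributed's internals.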