Mutate Zygote Gradients with a Custom Mask Before Update?

There’s no need to do that for the code you have above, because the only place gradients are “tracked” is inside the callback passed to `pullback`. In other words, there’s no such thing as a special `Zygote.gradient` object: `gs` is just a collection of ordinary arrays, and you can mask or modify them like any other array before the update :slight_smile:.
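A minimal sketch of what that looks like, assuming explicit-mode `Zygote.gradient` (which returns a tuple with one gradient per argument). The parameter matrix `W`, input `x`, toy `loss`, `mask`, and learning rate `η` are all hypothetical stand-ins for your own model and training code:

```julia
using Zygote

# Hypothetical toy parameters, input, and loss standing in for the real model.
W = [1.0 2.0; 3.0 4.0]
x = [1.0, -1.0]
loss(W) = sum(abs2, W * x)

# Explicit-mode `gradient` returns a tuple with one ordinary array per argument.
(gW,) = gradient(loss, W)

# Hypothetical mask: 0.0 freezes an entry, 1.0 lets it update.
mask = [1.0 0.0; 0.0 1.0]

# Because gW is just an array, masking is plain broadcasting.
gW_masked = gW .* mask

# Simple gradient-descent step, applied in place with the masked gradient.
η = 0.1
W .-= η .* gW_masked
```

In a real training loop you would apply the same broadcasted mask to the gradient right before handing it to your optimiser; nothing Zygote-specific is involved at that point.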

With that in mind, the error must be occurring somewhere within the call stack of `loss`. If you share a complete MWE (minimal working example) with the loss function and all the auxiliary training code, we can troubleshoot that part.