There’s no need to do that for the code you have above, because the only place gradients are “tracked” is inside the callback to `pullback`. In other words, there’s no such thing as a `Zygote.gradient` object, and `gs` is just a collection of arrays.
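
To illustrate, here is a minimal sketch, assuming a toy loss over plain arrays. The names `W`, `x`, and `loss` are hypothetical stand-ins, not taken from your code. It only shows that what comes back from the `pullback` callback is ordinary array data with nothing attached to it:

```julia
using Zygote

# Hypothetical toy problem; W, x and loss are placeholders for illustration.
W = [1.0 2.0; 3.0 4.0]
x = [1.0, 1.0]
loss(W, x) = sum(abs2, W * x)

# The forward pass runs here; differentiation only happens inside `back`.
l, back = Zygote.pullback(loss, W, x)

# Calling the callback returns one gradient per argument of `loss`.
gs = back(one(l))

# Each entry is a plain array, so there is no tracking state to detach or reset.
println(typeof(gs[1]))
println(typeof(gs[2]))
```

Once `back` has returned, `gs` can be stored, copied, or mutated like any other arrays.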
With that in mind, the error has to be occurring somewhere within the call stack of `loss`. If you share a complete MWE with the loss and all the auxiliary training code, we can troubleshoot that part.