Right, that has the same issue as the original example: you’re applying clip
to the gradient of the loss value, not to the gradients of the parameters.
One way to approach this would be to define a function like gradclip(x) = hook(clip, x)
and then do model = mapleaves(gradclip, model)
within the forward pass, so that the hook is attached to each parameter. I definitely agree that we could use some examples and/or utilities for this kind of thing, as it’s not obvious.
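
For what it's worth, here is a rough sketch of that approach. It assumes the Tracker-based Flux API in which hook and mapleaves are available, and that hook passes untracked leaves (such as activation functions) through unchanged. The clip definition, its threshold, and the toy model are made up for illustration and are not taken from the original example.

```julia
using Flux
using Flux.Tracker: hook

# Hypothetical clipping function: clamp every gradient entry to [-1, 1].
clip(Δ) = clamp.(Δ, -1, 1)

# Wrap a leaf so that `clip` is applied to its gradient during the backward pass.
gradclip(x) = hook(clip, x)

# Toy model, for illustration only.
model = Chain(Dense(10, 5, relu), Dense(5, 1))

function loss(x, y)
    # Rebuild a hooked copy of the model on each forward pass, so the hooks
    # sit between the loss and the original parameters.
    clipped = Flux.mapleaves(gradclip, model)
    Flux.mse(clipped(x), y)
end
```

The parameters handed to the optimiser would still be those of the original model; the hooked copy only exists so that gradients are clipped on their way back to them.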