How to use hook to clip a gradient?

Right, that has the same issue as the original example: you're applying `clip` to the gradient of the loss value rather than to the gradients of the parameters.

One way to approach this would be to define a function like `gradclip(x) = hook(clip, x)` and then do `model = mapleaves(gradclip, model)` within the forward pass, so that each parameter's gradient passes through `clip` during backpropagation. I definitely agree that we could use some examples and/or utilities for this kind of thing, as it's not obvious.
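To make the difference concrete, here is a minimal sketch in Python rather than Julia, using a toy scalar reverse-mode autodiff instead of Flux. Everything in it (`Var`, `backward`, `hook`, `clip`, `gradclip`, `forward`, and the dict of parameters standing in for a model) is a hypothetical stand-in written for illustration, not Flux's API. It only assumes that a hook is the identity on the forward pass and applies its function to the incoming gradient on the backward pass, and it contrasts hooking each parameter inside the forward pass with hooking the loss.

```python
# Toy scalar reverse-mode autodiff, for illustrating gradient hooks only.
# All names here are hypothetical Python stand-ins, NOT Flux's `hook`/`mapleaves`.

class Var:
    def __init__(self, value, parents=()):
        self.value = value
        self.grad = 0.0
        # Each parent is (node, fn): fn maps this node's gradient to the
        # contribution that flows back into `node`.
        self.parents = list(parents)

    def __add__(self, other):
        return Var(self.value + other.value,
                   [(self, lambda g: g), (other, lambda g: g)])

    def __mul__(self, other):
        return Var(self.value * other.value,
                   [(self, lambda g: g * other.value),
                    (other, lambda g: g * self.value)])


def backward(root):
    """Accumulate d(root)/d(node) into node.grad for every node in the graph."""
    order, seen = [], set()

    def visit(v):
        if id(v) in seen:
            return
        seen.add(id(v))
        for p, _ in v.parents:
            visit(p)
        order.append(v)  # appended only after all of its inputs

    visit(root)
    root.grad = 1.0
    for v in reversed(order):  # every consumer is processed before its inputs
        for p, fn in v.parents:
            p.grad += fn(v.grad)


def hook(f, x):
    """Identity on the forward pass; applies f to the gradient flowing into x."""
    return Var(x.value, [(x, f)])


def clip(g, bound=1.0):
    return max(-bound, min(bound, g))


def gradclip(x):
    return hook(clip, x)


def forward(params, x):
    # Analogue of `model = mapleaves(gradclip, model)`: wrap every parameter
    # in a fresh hook on each forward pass, then compute as usual.
    p = {name: gradclip(v) for name, v in params.items()}
    out = p["w"] * x + p["b"]
    return out * out  # squared output used as the "loss"


# Hooking each parameter: the gradients reaching w and b are clipped.
params = {"w": Var(2.0), "b": Var(0.5)}
backward(forward(params, Var(3.0)))
print(params["w"].grad, params["b"].grad)  # 1.0 1.0

# Hooking only the loss: clip sees d(loss)/d(loss) = 1.0, so nothing changes.
params2 = {"w": Var(2.0), "b": Var(0.5)}
out = params2["w"] * Var(3.0) + params2["b"]
backward(hook(clip, out * out))
print(params2["w"].grad, params2["b"].grad)  # 39.0 13.0
```

Hooking the loss only ever clips the seed gradient of 1.0, so the parameter gradients come through unchanged; hooking each parameter clips the gradient that actually reaches it. Re-wrapping the parameters on every call mirrors doing the `mapleaves` step inside the forward pass, so the hooks are part of each new graph.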
