Clipping gradients with Zygote/Flux

I am currently working with Flux, and because my data is highly stochastic, I often get “Loss is NaN” and “Loss is infinite” errors during training. Reducing the step size of my gradient updates helps, of course, but it slows down training unnecessarily.

I would like to avoid this issue by clipping gradients, but I have not found a way of doing it. What I want is to clip the gradient by its L^2 norm: if the gradient has a norm greater than 1 (or some other constant), I want to divide it by its norm.
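
For reference, the clipping rule described here can be written as a formula. The symbols are notation introduced only for this sketch: g is the gradient, c is the threshold, and \tilde{g} is the clipped gradient. With c = 1 it reduces to dividing by the norm:

```latex
\tilde{g} =
\begin{cases}
g \cdot \dfrac{c}{\lVert g \rVert_2} & \text{if } \lVert g \rVert_2 > c, \\
g & \text{otherwise.}
\end{cases}
```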

After searching around for a bit, I found the hook function in the Zygote package, which theoretically should be able to do this.
Here is the gradient descent function I currently have:

using Flux, Zygote

function pupdate!(S, A, δ, model, α, γ, t)
    # Log of the model output at index A for input S
    loss(x) = log(model(x)[A])
    local ps = Flux.params(model)
    local gs = Zygote.gradient(() -> loss(S), ps)
    #@info "neural network before: $(model(S)[A])"
    for p in ps
        # Scale each parameter's gradient and apply it as the update
        Flux.Tracker.update!(p, α * (γ^t) * δ .* gs[p])
    end
    #@info "neural network after: $(model(S)[A])"
end

I thought that replacing the gradient line with the following, together with the clipper function below it:

Zygote.gradient(() -> Zygote.hook(clipper, loss(S)), ps)


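using LinearAlgebra  # provides norm

# Rescale x to unit L^2 norm if its norm exceeds 1; otherwise return it unchanged
function clipper(x)
    if norm(x) > 1
        return x ./ norm(x)
    end
    return x
end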

should do the trick, but unfortunately this does not work.
Any help would be appreciated!

Is BatchNorm something that you are looking for?

Correct me if I am wrong, but I’m not sure BatchNorm helps me here. I frequently need to pass single arrays into my networks, so a BatchNorm layer would just return 0 for single inputs.
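
On the original question: as far as I understand Zygote.hook, it applies the function to the gradient of the value it wraps. Wrapping the scalar loss(S) therefore only touches the scalar sensitivity of the loss, which is 1 at the start of the backward pass, so clipper never changes anything. An alternative that avoids hooks is to compute the gradients as usual and rescale them before the update step. Below is a minimal sketch of that idea, assuming the per-parameter gradients can be collected into a plain vector of arrays. The names clip_global_norm! and grads are hypothetical and not part of Flux or Zygote, and the sketch clips the joint norm over all parameters rather than each array separately.

```julia
using LinearAlgebra

# Hypothetical helper, not part of Flux or Zygote: rescale a collection of
# gradient arrays in place so that their joint L2 norm is at most `threshold`.
function clip_global_norm!(grads, threshold::Real = 1.0)
    # Joint L2 norm over all arrays in the collection
    total = sqrt(sum(sum(abs2, g) for g in grads))
    if total > threshold
        for g in grads
            g .*= threshold / total
        end
    end
    return grads
end

# Stand-ins for the per-parameter gradients gs[p]
grads = [[3.0, 0.0], [0.0 4.0]]   # joint norm is 5
clip_global_norm!(grads)          # joint norm is now 1
```

In pupdate! this would mean collecting the gs[p] values, clipping them, and only then calling the update inside the loop. Whether gs[p] can be mutated in place, and whether some entries are nothing, depends on the Flux/Zygote version, so treat that part as an assumption to check.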