Flux gradient clipping



Does anyone have a working implementation of gradient clipping in Flux? I’ve tried modifying the param.grad field directly, but I get LoadError: type TrackedArray is immutable.


Can’t you define a special layer whose only job is to clip the gradient? Something along these lines:

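using Flux, Flux.Tracker
using Flux.Tracker: TrackedArray, track, @grad
lower_bound, upper_bound = -1.0, 1.0   # example bounds; choose your own
clip(x) = x                             # identity in the forward pass
clip(x::TrackedArray) = track(clip, x)  # route tracked arrays through the custom gradient
∇clip(Δ) = min.(max.(Δ, lower_bound), upper_bound)   # clamp the incoming gradient

@grad function clip(x)
    clip(Flux.data(x)), Δ -> (∇clip(Δ),)
end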

I think that if you register this, it should achieve your goal.


Flux provides hook for this. hook(clip, x) returns x unchanged in the forward pass but applies clip to the gradient flowing back into x, and within clip you can do whatever you want.
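
A minimal sketch of the hook approach, assuming a Tracker-based version of Flux and the argument order hook(f, x) from Flux.Tracker. The names clipgrad, LO and HI are made up for illustration, and the bounds are arbitrary.

```julia
using Flux, Flux.Tracker
using Flux.Tracker: hook, back!

# Hypothetical bounds, chosen only for illustration.
const LO, HI = -1.0, 1.0

# Elementwise clamp applied to the incoming gradient.
clipgrad(Δ) = clamp.(Δ, LO, HI)

W = param([1.0 2.0; 3.0 4.0])   # tracked parameter
x = [10.0, -20.0]

# The forward value is the same as sum(W * x); only the gradient
# that flows back into W passes through clipgrad.
loss = sum(hook(clipgrad, W) * x)
back!(loss)

# Without the hook the gradient would be [10 -20; 10 -20];
# with it, every entry is clamped into [LO, HI].
W.grad
```

Because the hook wraps a single use of W, you can clip some parameters and leave others alone by choosing where you call hook.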