Flux gradient clipping

Does anyone have a working implementation of gradient clipping in Flux? I’ve tried modifying the param.grad field directly, but I get LoadError: type TrackedArray is immutable.


Can’t you define a special layer whose only job is to clip the gradient? Something along the following lines:

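using Flux
using Flux.Tracker: TrackedArray, track, data, @grad

const lower_bound = -1.0f0   # example bounds; choose your own
const upper_bound = 1.0f0

clip(x) = x                                     # identity in the forward pass
clip(x::TrackedArray) = track(clip, x)          # send tracked arrays through the custom rule
∇clip(Δ) = clamp.(Δ, lower_bound, upper_bound)  # clip the incoming gradient elementwise

@grad function clip(x)
    return clip(data(x)), Δ -> (∇clip(Δ),)
end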

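If you register this with Tracker, passing a value through clip should leave it unchanged in the forward pass while clipping the gradient that flows back through it.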
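To make the mechanism concrete without depending on a particular Flux/Tracker version, here is a hand-written sketch of the forward/backward pair that the @grad rule describes: the forward pass returns x unchanged and the pullback clamps the incoming gradient Δ. The names clip_forward, LOWER and UPPER and the toy loss are illustrative assumptions, not part of Flux.

```julia
# Hypothetical, dependency-free illustration of a "clip the gradient" rule.
const LOWER = -1.0   # assumed example bounds
const UPPER = 1.0

# Forward pass is the identity; it also returns a pullback closure
# that clips whatever gradient arrives from downstream.
function clip_forward(x::AbstractArray)
    pullback(Δ) = clamp.(Δ, LOWER, UPPER)   # elementwise gradient clipping
    return x, pullback
end

# Toy loss L(y) = sum(y .^ 2) / 2, whose gradient w.r.t. y is y itself.
x = [0.5, -3.0, 2.0]
y, back = clip_forward(x)   # y is x, unchanged
dLdy = y                    # ∂L/∂y for the toy loss
dLdx = back(dLdy)           # gradient reaching x, clipped to [LOWER, UPPER]
println(dLdx)               # [0.5, -1.0, 1.0]
```

In Tracker the pullback is the closure returned from the @grad body, and the chain rule is applied for you; the sketch just writes that step out explicitly.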


Flux provides Tracker.hook for this. hook(f, x) returns x unchanged, but during backpropagation it applies f to the gradient of x, so f can do whatever you want, e.g. hook(Δ -> clamp.(Δ, -1, 1), x).
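Since f is an arbitrary function of the gradient, it does not have to clip elementwise. A minimal sketch, assuming Δ arrives as a plain array: the hypothetical helper clipnorm below rescales Δ so its L2 norm is at most maxnorm (norm clipping, a different technique from the elementwise clamp above). Note that it clips each array's gradient separately, not the global norm across all parameters.

```julia
using LinearAlgebra: norm

# Hypothetical helper: rescale Δ so that norm(Δ) <= maxnorm,
# leaving it untouched when it is already small enough.
function clipnorm(Δ::AbstractArray; maxnorm::Real = 1.0)
    n = norm(Δ)
    return n > maxnorm ? Δ .* (maxnorm / n) : Δ
end

clipnorm([3.0, 4.0])   # norm 5, rescaled to ≈ [0.6, 0.8]
clipnorm([0.3, 0.4])   # norm 0.5, returned unchanged
```

With Tracker this could then be passed as, for example, hook(Δ -> clipnorm(Δ; maxnorm = 5.0), x).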
