Does anyone have a working implementation of gradient clipping in Flux? I've tried modifying the `param.grad` field directly, but I get `LoadError: type TrackedArray is immutable`.
Can't you define a special layer whose only job is to clip the gradient? Something along the following lines:
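```julia
using Flux.Tracker: @grad, data, track, TrackedArray
const lower_bound = -1.0                        # example bounds, pick your own
const upper_bound = 1.0
clip(x) = x                                     # identity on the forward pass
clip(x::TrackedArray) = track(clip, x)          # register clip with Tracker
∇clip(Δ) = clamp.(Δ, lower_bound, upper_bound)  # clip the incoming gradient
@grad function clip(x)
    return clip(data(x)), Δ -> (∇clip(Δ),)
end
```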
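I think that if you register this so Tracker uses the custom gradient, calling `clip(x)` in your model leaves the forward pass unchanged but clips the gradient flowing back through `x`, which should achieve your goal.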
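The `∇clip` above clips each gradient entry independently. If you'd rather rescale the whole gradient by its norm, here is a minimal sketch of both transforms in plain Julia. The names `clip_value` and `clip_norm` and the thresholds are hypothetical; either function could stand in for `∇clip` (or be passed as the callback to `hook`).

```julia
using LinearAlgebra: norm

# Hypothetical helper: clamp every gradient entry into [-threshold, threshold].
clip_value(Δ, threshold) = clamp.(Δ, -threshold, threshold)

# Hypothetical helper: rescale Δ so its L2 norm is at most max_norm,
# keeping its direction. A zero gradient is returned unchanged.
function clip_norm(Δ, max_norm)
    n = norm(Δ)
    return n > max_norm ? Δ .* max_norm ./ n : Δ
end

Δ = [3.0, -4.0]           # L2 norm 5
clip_value(Δ, 1.0)        # [1.0, -1.0]: each entry clamped
clip_norm(Δ, 1.0)         # ≈ [0.6, -0.8]: same direction, norm 1
```

Value clipping can change the gradient's direction, while norm clipping only shrinks its length.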
Flux provides `hook` for this. `hook(clip, x)` returns `x` unchanged on the forward pass but applies `clip` to the gradient of `x` during backpropagation, and within `clip` you can do whatever you want.
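A minimal sketch of `hook` in use, assuming an older Flux where `Flux.Tracker` is the AD backend (the ±1 bound is an arbitrary example). Note that `hook` returns a new tracked value, and that returned value is what you must use downstream; the gradient is clipped on its way back to `x`.

```julia
using Flux.Tracker: param, hook, back!, grad

x = param([1.0, 2.0, 3.0])                      # tracked parameter

# Forward values are unchanged; the incoming gradient is clamped to [-1, 1].
x_clipped = hook(Δ -> clamp.(Δ, -1.0, 1.0), x)

loss = sum(10 .* x_clipped)   # d(loss)/d(x_clipped) = [10, 10, 10]
back!(loss)
grad(x)                       # [1.0, 1.0, 1.0] after clipping
```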