How to use hook to clip a gradient?

I’m trying to train a neural network using gradient clipping, and I’m trying to understand how Zygote’s hook function works.
Let’s define a model and a clip function:

using Zygote, LinearAlgebra
clip(g) = g ./ norm(g)
a = 2; b = 1; m(x, y) = a*x*(y - a) + b
gradient(m, 2, 3)          # returns (2, 4)
gradient(m, 2, 3) |> clip  # returns (0.4472135954999579, 0.8944271909999159)

But I’m using Flux’s train! function, so I need to provide a loss function whose gradient is clipped directly. However, this does not work:

gradient(2,3) do (x,y)
       Zygote.hook(clip, m(x,y))
end
ERROR: MethodError: no method matching (::var"#31#32")(::Int64, ::Int64)
Stacktrace:
 [1] macro expansion at C:\Users\Henri\.julia\packages\Zygote\ApBXe\src\compiler\interface2.jl:0 [inlined]
 [2] _pullback(::Zygote.Context, ::var"#31#32", ::Int64, ::Int64) at C:\Users\Henri\.julia\packages\Zygote\ApBXe\src\compiler\interface2.jl:7
 [3] _pullback(::Function, ::Int64, ::Int64) at C:\Users\Henri\.julia\packages\Zygote\ApBXe\src\compiler\interface.jl:29
 [4] pullback(::Function, ::Int64, ::Int64) at C:\Users\Henri\.julia\packages\Zygote\ApBXe\src\compiler\interface.jl:35
 [5] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at C:\Users\Henri\.julia\packages\Zygote\ApBXe\src\compiler\interface.jl:44
 [6] top-level scope at none:0

I cannot find any documentation on how to clip a gradient, which is strange because I think it’s a frequently used tool in deep learning.

The error here is actually in your do block; it should be x, y, not (x, y), since as written it’s expecting a tuple of two things rather than two arguments.
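
That is, keeping everything else from your example the same:

gradient(2, 3) do x, y          # x, y rather than (x, y)
  Zygote.hook(clip, m(x, y))
end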

However, that then returns (2, 4), which isn’t what you want. The reason is that hook applies to the gradient of the variable passed to it. This is slightly clearer if you write

gradient(2,3) do x, y
  l = m(x, y)
  Zygote.hook(clip, l)
end

clip gets passed l̄ (the gradient of the loss), which is always 1, so this doesn’t do anything useful. To apply clip to (x, y) you have to pass that value to hook. Here’s one way to write that:

gradient(2,3) do x, y
  x, y = Zygote.hook(clip, (x, y))
  m(x, y)
end
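
To spell out what hook is doing here: hook(f, x) returns x unchanged in the forward pass and applies f to the gradient flowing back into x. A tiny standalone illustration of that behaviour (the dx -> 10dx scaling is only for demonstration):

using Zygote

Zygote.gradient(3) do x
  y = Zygote.hook(dx -> 10dx, x)   # rescale the gradient reaching x by 10
  y^2
end
# returns (60,) rather than (6,): d(x^2)/dx at x = 3 is 6, and the hook multiplies it by 10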

If you can think of ways to make this stuff clearer in the docs, a PR would be huge! This is definitely the kind of area we’d like to be a bit more comprehensive on.

Thank you, it’s getting clearer. And how would you proceed to clip the gradient of a loss function? I tried this:

using Flux, Zygote, LinearAlgebra, Statistics  # Statistics for mean, Zygote for hook

model = Chain(Dense(2, 1))
clip(g) = g ./ norm(g)
loss(x, y) = mean((x .- y).^2)
batch = (rand(2, 2), rand(1, 2))

gs = gradient(Flux.params(model)) do
  loss(model(batch[1]), batch[2])
end

gs_clip = gradient(Flux.params(model)) do
  Zygote.hook(clip, loss(model(batch[1]), batch[2]))
end

But gs and gs_clip are equal. I’d happily make a PR, but I need to fully understand this first, since the typical use case in ML is to clip the gradient of the loss function.

Maybe it would also be interesting to provide clip as a built-in utility function.

Right, that has the same issue as the original example: you’re applying clip to the gradient of the loss value, not to the parameters.

One way to approach this would be to define a function like gradclip(x) = hook(clip, x) and then do model = mapleaves(gradclip, model) within the forward pass. Definitely agree that we could use some examples and/or utilities for that kind of thing as it’s not obvious.
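
For concreteness, here’s a minimal sketch of that suggestion, with the same model, loss, clip, and batch definitions as in the previous post. gradclip and mapleaves are taken from the comment above, but the exact wiring (calling mapleaves inside the do block while still taking gradients against Flux.params(model)) is my reading of it, so treat this as an illustration rather than an official recipe (newer Flux versions replace mapleaves with fmap):

using Flux, Zygote, LinearAlgebra, Statistics

model = Chain(Dense(2, 1))
clip(g) = g ./ norm(g)
loss(x, y) = mean((x .- y).^2)
batch = (rand(2, 2), rand(1, 2))

# hook is the identity in the forward pass, so the wrapped model computes the
# same loss; clip should then be applied to each parameter's gradient on the way back.
gradclip(x) = Zygote.hook(clip, x)

gs_clip = gradient(Flux.params(model)) do
  clipped = Flux.mapleaves(gradclip, model)  # wrap every leaf array of the model
  loss(clipped(batch[1]), batch[2])
end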

Ha yes indeed, this time it was on purpose, I thought the clipping heuristic was to clip the gradient of the error before backpropagating in the parameters. But it’s the gradient w.r.t. the parameters which is clipped. Thanks for your time, I’ll practice.
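
For reference, a rough sketch of that reading, reusing model, clip, and the gs computed earlier: clip each parameter’s gradient after the backward pass and before the optimiser update (this loop is just an illustration, not something built into Flux):

for p in Flux.params(model)
  g = gs[p]
  g === nothing && continue   # skip parameters that received no gradient
  g .= clip(g)                # overwrite in place with the norm-clipped gradient
end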