I'm trying to train a neural network using gradient clipping, and I'm trying to understand how Zygote's hook function works.
Let's define a model and a clip function:
using Zygote, LinearAlgebra
clip(g) = g ./ norm(g)
a = 2; b = 1
m(x, y) = a*x*(y - a) + b
gradient(m, 2, 3)         # returns (2, 4)
gradient(m, 2, 3) |> clip # returns (0.4472135954999579, 0.8944271909999159)
But I'm using Flux's train! function, so I need to provide a loss function that clips its own gradient. However, this does not work:
gradient(2, 3) do (x, y)
    Zygote.hook(clip, m(x, y))
end
ERROR: MethodError: no method matching (::var"#31#32")(::Int64, ::Int64)
Stacktrace:
[1] macro expansion at C:\Users\Henri\.julia\packages\Zygote\ApBXe\src\compiler\interface2.jl:0 [inlined]
[2] _pullback(::Zygote.Context, ::var"#31#32", ::Int64, ::Int64) at C:\Users\Henri\.julia\packages\Zygote\ApBXe\src\compiler\interface2.jl:7
[3] _pullback(::Function, ::Int64, ::Int64) at C:\Users\Henri\.julia\packages\Zygote\ApBXe\src\compiler\interface.jl:29
[4] pullback(::Function, ::Int64, ::Int64) at C:\Users\Henri\.julia\packages\Zygote\ApBXe\src\compiler\interface.jl:35
[5] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at C:\Users\Henri\.julia\packages\Zygote\ApBXe\src\compiler\interface.jl:44
[6] top-level scope at none:0
I cannot find any documentation on how to clip a gradient, which is odd since it's a frequently used tool in deep learning, I think.
The error here is actually in your do block; it should be x, y, not (x, y), since as written it's expecting a tuple of two things rather than two arguments.
However, that then returns (2, 4), which isn't what you want. The reason is that hook applies to the gradient of the variable passed to it. This is slightly clearer if you write
gradient(2, 3) do x, y
    l = m(x, y)
    Zygote.hook(clip, l)  # clip is applied to l̄, the gradient of l, not to (x̄, ȳ)
end
clip gets passed l̄, the gradient of the loss, which is always 1, so this doesn't do anything useful. To apply clip to (x, y) you have to pass that value to hook. Here's one way to write that:
gradient(2, 3) do x, y
    x, y = Zygote.hook(clip, (x, y))  # clip now applies to the gradient w.r.t. (x, y)
    m(x, y)
end
# returns the clipped gradient (0.4472135954999579, 0.8944271909999159)
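As a quick sanity check (just a sketch, reusing m from the first post), you can hook a pass-through printing function onto the loss to see that the loss itself only ever receives the seed gradient 1, which is why clipping at that point is a no-op:
gradient(2, 3) do x, y
    # pass-through hook: prints the incoming gradient of the loss, then returns it unchanged
    Zygote.hook(ȳ -> (println("loss gradient: ", ȳ); ȳ), m(x, y))
end
# prints the seed gradient (1) and still returns (2, 4)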
If you can think of ways to make this stuff clearer in the docs, a PR would be huge! This is definitely the kind of area we'd like to be a bit more comprehensive on.
Thank you, it's getting clearer. And how would you proceed to clip the gradient of a loss function? I tried this:
using Flux, Zygote, LinearAlgebra, Statistics  # Statistics for mean, Zygote for hook

model = Chain(Dense(2, 1))
clip(g) = g ./ norm(g)
loss(x, y) = mean((x .- y).^2)
batch = (rand(2, 2), rand(1, 2))

gs = gradient(Flux.params(model)) do
    loss(model(batch[1]), batch[2])
end

gs_clip = gradient(Flux.params(model)) do
    Zygote.hook(clip, loss(model(batch[1]), batch[2]))
end
But gs and gs_clip are equal. I'd happily make a PR, but I need to fully understand this first, since the typical use-case in ML is to clip the gradient of the loss function.
Maybe it would also be interesting to provide clip as a built-in utility function.
Right, that has the same issue as the original example: you're applying clip to the gradient of the loss value, not to the parameters.
One way to approach this would be to define a function like gradclip(x) = hook(clip, x) and then do model = mapleaves(gradclip, model) within the forward pass. Definitely agree that we could use some examples and/or utilities for that kind of thing, as it's not obvious.
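For concreteness, here's a minimal sketch of that idea applied to the example above; it assumes a Flux version that still provides mapleaves, and gradclip / clipped_model are just illustrative names:
using Flux, Zygote, LinearAlgebra, Statistics

clip(g) = g ./ norm(g)
gradclip(x) = Zygote.hook(clip, x)  # identity on the forward pass, clips on the backward pass

model = Chain(Dense(2, 1))
loss(x, y) = mean((x .- y).^2)
batch = (rand(2, 2), rand(1, 2))

gs_clip = gradient(Flux.params(model)) do
    clipped_model = Flux.mapleaves(gradclip, model)  # hook clip onto every parameter's gradient
    loss(clipped_model(batch[1]), batch[2])
end
Note that with clip defined this way, each parameter array's gradient is normalised to unit norm individually.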
Ha yes, indeed, this time it was on purpose: I thought the clipping heuristic was to clip the gradient of the error before backpropagating into the parameters, but it's the gradient w.r.t. the parameters that gets clipped. Thanks for your time, I'll practice.