I am trying to optimize a NN with an Optim.jl optimizer by using `destructure`/`restructure` on the model, and Zygote to obtain the gradient of the loss function. I noticed that Zygote does not seem to be able to calculate the gradient of an RNN correctly. In the following code, the length of the destructured parameter vector `p` should be the same as that of the gradient `g(p)`; however, the length of the gradient is 0 for an `LSTM`, `GRU`, and `RNN`, whereas the lengths are equal for e.g. a `Dense` layer:
```julia
using Flux, Zygote, Statistics

x = rand(Float32, 10, 200)
y = rand(Float32, 1, 200)

model = LSTM(10, 1)  # change to Dense, GRU, etc.
p, re = Flux.destructure(model)

function loss(p)
    m = re(p)
    # apply the model step-by-step over the 200 time steps, then compare to y
    mean(abs2, Flux.stack(m.(Flux.unstack(x, 2)), 1) .- y)
end

g = p -> Zygote.gradient(loss, p)[1]
@info length(p), length(g(p))
```
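For reference, the `Dense` control case mentioned above can be written as the sketch below (a `Dense` layer is stateless, so no `unstack`/`stack` over time steps is needed); here both lengths come out equal:

```julia
using Flux, Zygote, Statistics

x = rand(Float32, 10, 200)
y = rand(Float32, 1, 200)

# Control case: a Dense layer instead of a recurrent one.
model = Dense(10, 1)
p, re = Flux.destructure(model)  # p has 11 entries: 10 weights + 1 bias

loss(p) = mean(abs2, re(p)(x) .- y)

g = Zygote.gradient(loss, p)[1]
@info length(p), length(g)  # lengths match for Dense
```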
Is this expected behaviour? I’m using Flux v0.11.2 and Zygote v0.5.17.