I am trying to optimize a neural network with an Optim.jl optimizer, using Flux.destructure/restructure on the model and Zygote to obtain the gradient of the loss function. I noticed that Zygote does not seem to calculate the gradient of a recurrent model correctly: in the code below, the destructured parameter vector p and the gradient g(p) should have the same length, but the gradient has length 0 for LSTM, GRU and RNN, whereas the lengths are equal for e.g. a Dense layer:
using Flux, Zygote, Statistics
x = rand(Float32, 10, 200)  # 10 input features × 200 timesteps
y = rand(Float32, 1, 200)   # one target per timestep
model = LSTM(10, 1)         # change to Dense(10, 1), GRU(10, 1), RNN(10, 1) etc.
p, re = Flux.destructure(model)  # flat parameter vector and rebuilder
function loss(p)
    m = re(p)  # rebuild the model from the flat parameter vector
    mean(abs2, Flux.stack(m.(Flux.unstack(x, 2)), 1) .- y)
end
g = p -> Zygote.gradient(loss, p)[1]
@info length(p), length(g(p))  # should be equal, but g(p) has length 0 for LSTM/GRU/RNN
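For comparison, here is a minimal sketch of the same gradient taken with implicit parameters via Flux.params instead of destructure (my assumption is that this is a fair control; adjust as needed). I would expect this route to return non-empty gradients, which would suggest the problem is specific to destructure:

# sketch: implicit-parameter route for comparison
Flux.reset!(model)  # reset the hidden state before taking gradients
ps = Flux.params(model)
gs = Zygote.gradient(ps) do
    mean(abs2, Flux.stack(model.(Flux.unstack(x, 2)), 1) .- y)
end
# total number of gradient entries across all parameter arrays
@info sum(p -> gs[p] === nothing ? 0 : length(gs[p]), ps)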
Is this expected behaviour? I’m using Flux v0.11.2 and Zygote v0.5.17.