Hi everyone,
I am fairly new to Julia and Flux, and I am trying to implement a WGAN variant that I have previously implemented in Keras. The full code for the issue can be found here.
I keep getting gradients that consist of pairs of (“SOME_KEY”, nothing), which in the end means the model is not updated when Flux.Optimise.update! is called.
These are the problematic lines:
...
grads = Flux.gradient(ps) do
    loss = wasserstein_g_loss(gen_data)  # gen_data is the generator output, computed before this closure
end;
Flux.update!(opt_gen, ps, grads)
...
I have tried a bunch of variations of calling Flux.gradient and Flux.pullback, but nothing seems to work for me. Both methods work for me in other contexts, and I am reaching the end of my rope here.
If you read through this post and took a look at the code, I thank you kindly for your time.
Cheers!
Welcome!
Flux/Zygote’s gradient APIs expect you to run the forward pass of the model inside the callback. In your case, only the output of the generator is actually referenced in the loss, so Zygote never sees the parameter arrays being used and can’t generate a gradient for them. If you look at the tutorials, docs or model zoo examples, you’ll see they explicitly reference all or part of the model in the loss function.
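For concreteness, here is a minimal sketch of that pattern, with a toy Dense model, loss and data standing in for your generator setup:

using Flux

model = Dense(10, 1)           # stand-in for the generator
loss_fn(ŷ) = sum(ŷ)            # stand-in scalar loss
x = rand(Float32, 10, 32)      # stand-in batch

ps = Flux.params(model)
grads = Flux.gradient(ps) do
    ŷ = model(x)               # the forward pass runs *inside* the closure,
    loss_fn(ŷ)                 # so Zygote sees the parameters being used
end
# grads[p] is now non-nothing for every parameter the forward pass touched.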
Since you mentioned being familiar with Keras, this is similar to how tf.GradientTape works. The upside is that there is no need for stateful .grad properties or .zero_grad methods on arrays/“tensors”. The (tiny) downside is that the code that runs in the callback needs to use every parameter you’re interested in differentiating with respect to. The easiest way to do that, of course, is to run the forward pass in said callback.
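And to illustrate the failure mode from your snippet: if the closure only references a precomputed output, the parameters come back with nothing gradients, which is exactly the (“SOME_KEY”, nothing) pattern you described (again a toy sketch, not your code):

using Flux

model = Dense(10, 1)
x = rand(Float32, 10, 32)
ps = Flux.params(model)

y = model(x)                 # forward pass happens *outside* the closure
grads = Flux.gradient(ps) do
    sum(y)                   # only the precomputed output is referenced
end

all(p -> grads[p] === nothing, ps)   # true: no gradient reaches the parameters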
Thank you for taking the time to educate me; I really appreciate it.
Having to include the forward pass makes perfect sense, and I have indeed tried this in an earlier variation. This is exactly the approach I use with gradient tape in Keras/TF.
function train_gen!(gen, crit, opt_gen, sample_data)
    ps = Flux.params(gen)
    x_dim, y_dim = size(sample_data)
    noise = reshape(rand(x_dim * y_dim), (x_dim, y_dim))
    gen_in = hcat(noise, sample_data)
    local loss
    loss, back = Flux.pullback(ps) do
        gen_data = gen(gen_in)                           # forward pass inside the closure
        gen_norm = gradient_normalization(crit, gen_data)
        loss = wasserstein_g_loss(gen_norm)
    end
    grads = back(1f0)                                    # seed the pullback to get the gradients
    Flux.Optimise.update!(opt_gen, ps, grads)
    return loss
end
Upon examining it, I do in fact get gradients w.r.t. ps (Hurray!), but the model parameters do not change after the Flux.Optimise.update! call.
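For what it's worth, this is roughly how I am checking that (a quick sketch, using gen, crit, opt_gen and sample_data from my script):

before = deepcopy(Flux.params(gen))            # snapshot of the parameters
loss = train_gen!(gen, crit, opt_gen, sample_data)
after = Flux.params(gen)
all(a ≈ b for (a, b) in zip(before, after))    # still evaluates to true, i.e. nothing moved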
I will keep at it. Thank you again for your time and effort.
You can inspect ps and grads to make sure the same params are being included in both (grads will usually have more, so make sure everything you want optimized from ps is in there). If visual inspection is too difficult, you can loop over the params like so:
for p in ps
    @assert haskey(grads, p)  # every parameter you want optimized should have an entry in grads
end
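If everything from ps does show up, it's also worth checking whether any of the gradients are nothing or all zeros, since either would leave the parameters unchanged after update! (rough sketch):

for p in ps
    g = grads[p]
    if g === nothing || iszero(g)
        @warn "no useful gradient for a parameter" size(p)
    end
end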