Flux.gradient returns a dict of params mapped to Nothing

Hi everyone,
I am fairly new to Julia and Flux, and I am trying to implement a WGAN variant that I have previously implemented in Keras. Full code for the issue can be found here.

I keep getting gradients that consist of (“SOME_KEY”, “Nothing”) pairs, which in the end means Flux never updates the model when Flux.Optimise.update! is called.
These are the problematic lines:

... 
grads = Flux.gradient(ps) do
        loss = wasserstein_g_loss(gen_data)
end;
Flux.update!(opt_gen, ps, grads)
...

I have tried a bunch of variations of calling Flux.gradient and Flux.pullback but nothing seems to work for me. Both methods work for me in other contexts, and I am reaching the end of my rope here.

If you read through this post and took a look at the code, I thank you kindly for your time.

Cheers!

Welcome!

Flux/Zygote’s gradient APIs expect you to run the forward pass of the model inside the callback. In your case, only the output of the generator is referenced in the loss, so Zygote never sees the parameter arrays being used and cannot generate gradients for them. If you look at the tutorials, docs or model zoo examples, you’ll see they explicitly reference all or part of the model in the loss function.

Since you mentioned being familiar with Keras, this is similar to how tf.GradientTape works. The upside is that there is no need for stateful .grad properties or .zero_grad methods on arrays/“tensors”. The (tiny) downside is that the code that runs in the callback needs to use every parameter you want to differentiate with respect to. The easiest way to do that, of course, is to run the forward pass in said callback.
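A minimal sketch of the difference, assuming a Flux release from the 0.13/0.14 era (which still supports implicit Flux.params) and using a toy Dense layer as a hypothetical stand-in for the generator. The first gradient call only sees a precomputed output; the second runs the forward pass inside the callback:

```julia
using Flux  # assumes Flux 0.13/0.14-era implicit parameters (Flux.params)

model = Dense(3 => 1)        # toy stand-in for the generator (illustrative)
x = rand(Float32, 3, 8)      # toy input batch: 3 features, 8 samples
ps = Flux.params(model)

# Forward pass OUTSIDE the callback: the loss only sees a plain array,
# so Zygote cannot connect it to the weights in `ps`.
y = model(x)
gs_outside = gradient(() -> sum(abs2, y), ps)

# Forward pass INSIDE the callback: every parameter is used while
# Zygote is tracing, so each one receives a gradient.
gs_inside = gradient(() -> sum(abs2, model(x)), ps)

@show gs_outside[model.weight]   # nothing
@show gs_inside[model.weight]    # a 1×3 Float32 matrix
```

The nothing entry in the first case is the same symptom described in the original question.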

Thank you for taking your time to educate me, I really appreciate this.

Having to include the forward pass makes perfect sense, and I have indeed tried this in an earlier variation. This is exactly the approach I use with gradient tape in Keras/TF.

function train_gen!(gen, crit, opt_gen, sample_data)
    ps = Flux.params(gen)
    x_dim, y_dim = size(sample_data)
    noise = rand(x_dim, y_dim)              # uniform noise, same shape as sample_data
    gen_in = hcat(noise, sample_data)
    loss, back = Flux.pullback(ps) do
        gen_data = gen(gen_in)              # forward pass inside the callback
        gen_norm = gradient_normalization(crit, gen_data)
        wasserstein_g_loss(gen_norm)
    end
    grads = back(1f0)                       # Zygote.Grads keyed by the arrays in ps
    Flux.Optimise.update!(opt_gen, ps, grads)
    return loss
end

Upon examining it, I do in fact get gradients w.r.t. ps (Hurray!), but the model parameters do not change after the Flux.Optimise.update! call.
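One way to narrow down "parameters don't change" is to snapshot them, apply a single update, and compare. The sketch below does this on a toy Dense model with a plain Descent optimiser; both are hypothetical stand-ins for gen and opt_gen, and it again assumes the implicit-parameter API of Flux 0.13/0.14:

```julia
using Flux  # assumes the implicit-parameter API (Flux 0.13/0.14 era)

model = Dense(3 => 1)            # toy stand-in for `gen` (illustrative)
x = rand(Float32, 3, 8)
ps = Flux.params(model)
opt = Descent(0.1)               # plain gradient descent, chosen for illustration

before = [copy(p) for p in ps]   # snapshot of every parameter array
gs = gradient(() -> sum(abs2, model(x)), ps)
Flux.Optimise.update!(opt, ps, gs)

# `true` for each parameter that update! actually modified in place
changed = [p != b for (p, b) in zip(ps, before)]
@show changed
```

In this Flux version, update! skips any parameter whose gradient is nothing, so a false entry points to a missing (or all-zero) gradient for that array.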

I will keep at it. Thank you again for your time and effort.

You can inspect ps and grads to make sure the same params are being included in both (grads will usually have more, so make sure everything you want optimized from ps is in there). If visual inspection is too difficult, you can loop over the params like so:

  for p in ps
    # each parameter should be present and have a non-nothing gradient
    @assert haskey(grads, p) && grads[p] !== nothing
  end