Flux.gradient returns dict of param and Nothing

Welcome!

Flux/Zygote’s gradient APIs expect you to run the forward pass of the model inside the callback. In your case, only the output of the generator is actually referenced in the loss, so Zygote never sees the parameter arrays being used and can’t generate gradients for them. If you look at the tutorials, docs or model zoo examples, you’ll see they explicitly reference all or part of the model in the loss function.
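Here’s a minimal sketch of the failure mode and the fix, using the implicit-parameters API (`Flux.params`); the generator `gen`, input `x`, and loss here are stand-ins for whatever your model actually looks like:

```julia
using Flux

# Stand-in "generator" and input batch, just for illustration
gen = Chain(Dense(10, 32, relu), Dense(32, 2))
x = randn(Float32, 10, 8)

# Broken: the forward pass runs *outside* the gradient callback, so the
# closure only references a plain output array. Zygote never sees the
# parameters being used and returns `nothing` for each of them.
out = gen(x)
gs_bad = Flux.gradient(() -> sum(abs2, out), Flux.params(gen))
# gs_bad[p] === nothing for every parameter p

# Working: run the forward pass *inside* the callback so Zygote can
# trace how every parameter contributes to the loss.
gs_good = Flux.gradient(() -> sum(abs2, gen(x)), Flux.params(gen))
# gs_good[p] is now a proper gradient array
```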

Since you mentioned being familiar with Keras, this is similar to how `tf.GradientTape` works in TensorFlow. The upside is that there is no need for stateful `.grad` properties or `.zero_grad` methods on arrays/“tensors”. The (tiny) downside is that the code that runs in the callback needs to use every parameter you’re interested in differentiating with respect to. The easiest way to do that, of course, is to run the forward pass in said callback.
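In practice that just means computing the loss inside the `gradient` call on every step. A sketch of what a training step could look like (the model, data, and hyperparameters here are all made up):

```julia
using Flux

model = Chain(Dense(4, 8, relu), Dense(8, 1))
x, y = randn(Float32, 4, 16), randn(Float32, 1, 16)

ps = Flux.params(model)
opt = Descent(0.01f0)

for step in 1:10
    # Each call to `gradient` produces a fresh Grads object, so there is
    # no accumulated state to reset (no `.zero_grad` analogue).
    gs = Flux.gradient(ps) do
        Flux.Losses.mse(model(x), y)  # forward pass inside the callback
    end
    Flux.Optimise.update!(opt, ps, gs)
end
```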
