Thank you for taking the time to educate me; I really appreciate it.
Having to include the forward pass makes perfect sense, and I have indeed tried this in an earlier variation. This is exactly the approach I use with gradient tape in Keras/TF.
function train_gen!(gen, crit, opt_gen, sample_data)
    ps = Flux.params(gen)
    x_dim, y_dim = size(sample_data)
    noise = rand(x_dim, y_dim)
    gen_in = hcat(noise, sample_data)
    # pullback returns the loss value and a closure that computes the gradients
    loss, back = Flux.pullback(ps) do
        gen_data = gen(gen_in)  # Forward pass
        gen_norm = gradient_normalization(crit, gen_data)
        wasserstein_g_loss(gen_norm)
    end
    grads = back(1f0)
    Flux.Optimise.update!(opt_gen, ps, grads)
    return loss
end
Upon examining it, I do in fact get gradients w.r.t. ps (hurray!), but the model parameters do not change after the Flux.Optimise.update! call.
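One thing that may be worth ruling out when checking whether the parameters moved: the update happens in place, so a "before" reference that was not copied points at the same arrays and will always look identical to the "after" state. The sketch below is plain Julia with no Flux dependency. `params_changed` is a hypothetical helper, the two small arrays stand in for what iterating over `Flux.params(gen)` would yield, and the loop is a hand-written gradient step standing in for `Flux.Optimise.update!`. It only illustrates the check itself and is not a diagnosis of the function above.

```julia
# Hypothetical helper: true if any parameter array differs from its snapshot.
params_changed(snapshot, ps) = any(s != p for (s, p) in zip(snapshot, ps))

# Stand-ins for the arrays that iterating over Flux.params(gen) would yield.
ps    = [ones(Float32, 2, 2), zeros(Float32, 2)]
grads = [fill(0.5f0, 2, 2), fill(1f0, 2)]

aliased  = [p for p in ps]        # same arrays, not a copy
snapshot = [copy(p) for p in ps]  # independent copy taken before the update

# In-place gradient step, standing in for Flux.Optimise.update!(opt_gen, ps, grads).
for (p, g) in zip(ps, grads)
    p .-= 0.1f0 .* g
end

println(params_changed(aliased, ps))   # false: the aliased reference cannot see the change
println(params_changed(snapshot, ps))  # true: the copied snapshot can
```

In the training function, the equivalent check would be to take `[copy(p) for p in ps]` just before the update call and compare against `ps` afterwards.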
I will keep at it. Thank you again for your time and effort.