Hey @HerAdri,
I was finally able to solve something, thanks to @ToucheSir, who explained in another post how gradients are calculated in Zygote
here.
Taking that comment into consideration, I moved every calculation that needs to be tracked for the backward pass into the gradient do block.
So I changed the code as follows, and as far as I can tell it does the job (please confirm or correct me if not):
# lowest-level?
data = [(y, x)]
for epoch in 1:n_epochs
    # initialise the loss here so it is still visible after the do block
    l = 0f0
    # forward and backward pass: everything computed inside the do block is tracked
    gs = gradient(pars) do
        pi_out = model[:pi](y)
        sigma_out = model[:sigma](y)
        mu_out = model[:mu](y)
        l = mdn_loss(pi_out, sigma_out, mu_out, x)
    end
    Flux.update!(opt, pars, gs)
    if epoch % 1000 == 0
        println("Epoch: ", epoch, " loss: ", l)
    end
end
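
To show why moving the forward pass into the do block matters, here is a minimal sketch on a toy problem. It is not the MDN model from this thread: `w`, `x` and the squared-sum loss are made-up stand-ins, and it assumes a Zygote version recent enough to provide `Zygote.withgradient`. A value computed outside the closure is just a constant to Zygote, so its gradient comes back as `nothing`, while the same computation inside the closure is differentiated.

```julia
using Zygote

# Hypothetical toy data, only for illustration
w = [1.0, 2.0]
x = [3.0, 4.0]

# Forward pass done OUTSIDE the closure: Zygote never sees how w was used
pred_outside = sum(w .* x)
g_outside = Zygote.gradient(w -> pred_outside^2, w)

# Forward pass done INSIDE the closure: the whole computation is tracked
res = Zygote.withgradient(w -> sum(w .* x)^2, w)

println("outside: ", g_outside)        # gradient w.r.t. w is `nothing`
println("inside loss: ", res.val)      # loss value from the same call
println("inside grad: ", res.grad[1])  # equals 2 * sum(w .* x) .* x
```

`withgradient` returns the loss together with the gradient, so it may also be a way to avoid assigning `l` from inside the do block, though I have not checked that against the training loop above.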