Sorry, in your case the model is made up of `pi`, `sigma` and `mu` →
# Low-level (explicit) training loop using Flux's implicit-parameters API for the
# three model heads `pi`, `sigma` and `mu`.
# NOTE(review): `pars`, `opt`, `n_epochs`, `y`, `x` and `mdn_loss` are assumed to be
# defined earlier in the notebook; `pi` here shadows `Base.pi`.
data = [(y, x)]  # NOTE(review): not used by the loop below; kept in case later cells read it
for epoch in 1:n_epochs
    # `local` forces a fresh binding in the loop body, which the gradient closure below
    # captures and writes to. Without it, in REPL/Jupyter (soft scope) an existing
    # global `l` would be assigned here instead. The closure is a hard scope, so its
    # `l = ...` would then create its own unrelated local and the printed loss would
    # always be 0.
    local l = 0f0
    # Forward pass + loss run inside the closure so Zygote can differentiate them
    # with respect to `pars`; `gs` holds the resulting gradients (backward pass).
    gs = gradient(pars) do
        pi_out = pi(y)
        sigma_out = sigma(y)
        mu_out = mu(y)
        # Record the loss value for logging; it is also the closure's return value.
        l = mdn_loss(pi_out, sigma_out, mu_out, x)
    end
    Flux.update!(opt, pars, gs)
    # Report the loss computed before this epoch's parameter update.
    if epoch % 1000 == 0
        println("Epoch: ", epoch, " loss: ", l)
    end
end
Epoch: 1000 loss: 2.901283803535177
Epoch: 2000 loss: 2.683036585344914
Epoch: 3000 loss: 2.496097359141723
Epoch: 4000 loss: 1.8825208240140192
Epoch: 5000 loss: 1.5377693999566997
Epoch: 6000 loss: 1.4791544586544207
Epoch: 7000 loss: 1.4550877233830257
Epoch: 8000 loss: 1.4337247960818376
1 Like
But the final graph is still not as expected:
But I’m able to reproduce the plots.
I will try to clean up the notebook I created and share it once it’s done!
I have placed my code in this gist.
P.S.: I’m not using the `Distributions.jl` package. If you want, you can use its existing functions instead of the ones I have written.
1 Like