Need help with example: Mixture Density Networks from Site

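Sorry, in your case the model is `pi`, `sigma` and `mu`, so the parameters of all three heads have to be collected and used inside the `gradient` block: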

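```julia
using Flux

# lowest-level training loop (no Flux.train!)
# assumes the three heads `pi`, `sigma`, `mu`, the optimiser `opt`, `n_epochs`,
# the inputs `y` and the targets `x` are defined earlier
pars = Flux.params(pi, sigma, mu)   # trainable parameters of all three heads

for epoch in 1:n_epochs
    local l   # holds the loss value computed inside the closure below

    # forward and backward pass: gradient() runs the do-block and differentiates it
    gs = gradient(pars) do
        pi_out = pi(y)
        sigma_out = sigma(y)
        mu_out = mu(y)
        l = mdn_loss(pi_out, sigma_out, mu_out, x)
        return l
    end
    Flux.update!(opt, pars, gs)

    if epoch % 1000 == 0
        println("Epoch: ", epoch, " loss: ", l)
    end
end
```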
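The loop calls `mdn_loss` without showing it. A minimal sketch of the usual MDN negative log-likelihood, assuming each head returns a K×N matrix (K mixture components, N samples), each column of `pi_out` sums to 1 (softmax output), `sigma_out` is positive (e.g. an `exp` output) and `x` is a 1×N row of targets. `gaussian_pdf` is a hypothetical helper name, and this is not necessarily the function used in the gist:

```julia
using Statistics: mean

# Hypothetical helper: Gaussian density N(x | mu, sigma^2), evaluated elementwise.
# A 1×N `x` broadcasts against K×N `mu` and `sigma`.
gaussian_pdf(x, mu, sigma) = exp.(-0.5 .* ((x .- mu) ./ sigma) .^ 2) ./ (sigma .* sqrt(2π))

# Sketch of the MDN loss (assumed shapes: K×N heads, 1×N targets).
function mdn_loss(pi_out, sigma_out, mu_out, x)
    # mixture likelihood of each sample: weighted sum over the K components (dims = 1)
    likelihood = sum(pi_out .* gaussian_pdf(x, mu_out, sigma_out); dims = 1)
    # average negative log-likelihood over the batch
    return -mean(log.(likelihood))
end
```

It only uses broadcasting, `sum`, `log` and `mean`, so Zygote can differentiate it inside the `gradient` block.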

:+1:

```
Epoch: 1000 loss: 2.901283803535177
Epoch: 2000 loss: 2.683036585344914
Epoch: 3000 loss: 2.496097359141723
Epoch: 4000 loss: 1.8825208240140192
Epoch: 5000 loss: 1.5377693999566997
Epoch: 6000 loss: 1.4791544586544207
Epoch: 7000 loss: 1.4550877233830257
Epoch: 8000 loss: 1.4337247960818376
```

But the final graph still doesn’t look as expected:

[Image: final graph]
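The plotting code isn’t shown in the thread, so this is only a guess at one thing worth checking. In the inverse problem (the loop feeds `y` in and scores against `x`), one input can map to several outputs, so plotting the component means, or only the most likely component, looks very different from drawing samples from the whole mixture. A minimal sampler, assuming the same K×N head outputs as above; `sample_mdn` is a hypothetical name:

```julia
using Random

# Hypothetical helper: draw one sample from the 1-D Gaussian mixture in each column.
# pi_out, sigma_out, mu_out :: K×N (column j describes the mixture for input j).
function sample_mdn(rng::AbstractRNG, pi_out, sigma_out, mu_out)
    K, N = size(pi_out)
    samples = Vector{Float64}(undef, N)
    for j in 1:N
        # pick component k with probability pi_out[k, j] (inverse CDF on the weights);
        # fall back to K if rounding leaves the cumulative sum just below u
        u = rand(rng)
        k = something(findfirst(>(u), cumsum(view(pi_out, :, j))), K)
        # then draw from that component's Gaussian
        samples[j] = mu_out[k, j] + sigma_out[k, j] * randn(rng)
    end
    return samples
end
```

With Plots.jl, something like `scatter(vec(y), sample_mdn(Random.default_rng(), pi(y), sigma(y), mu(y)))` would then show the sampled outputs for each input.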

But I’m able to reproduce the plots :thinking:.
I’ll try to clean up the notebook I created and share it once it’s done!

I have put my code in this gist.

P.S.: I’m not using the Distributions.jl package; if you want, you can use its existing functions instead of the ones I wrote.
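On the P.S.: if a hand-written loss multiplies the weights by raw Gaussian densities and only then takes `log`, it can underflow to `log(0) = -Inf` for points far from every component. A sketch of the same negative log-likelihood computed in log space, assuming K×N head outputs and 1×N targets; `normal_logpdf`, `logsumexp_cols` and `mdn_loss_stable` are hypothetical names. If you’d rather not hand-roll these, Distributions.jl’s `logpdf(Normal(μ, σ), x)` and LogExpFunctions.jl’s `logsumexp` cover the same ground.

```julia
using Statistics: mean

# Numerically stable log-sum-exp down each column (dims = 1).
function logsumexp_cols(a)
    m = maximum(a; dims = 1)                       # shift by the column max
    return m .+ log.(sum(exp.(a .- m); dims = 1))
end

# Log-density of N(mu, sigma^2) at x, elementwise.
normal_logpdf(x, mu, sigma) = -0.5 .* ((x .- mu) ./ sigma) .^ 2 .- log.(sigma) .- 0.5 * log(2π)

# MDN negative log-likelihood computed entirely in log space.
function mdn_loss_stable(pi_out, sigma_out, mu_out, x)
    logp = log.(pi_out) .+ normal_logpdf(x, mu_out, sigma_out)   # K×N
    return -mean(logsumexp_cols(logp))
end
```

For a single component with `mu = 0`, `sigma = 1` and a target at `x = 100`, the direct density `exp(-5000)` is `0.0` in Float64, while the log-space version still returns a finite loss of about 5000.92.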
