Need help with example: Mixture Density Networks from site

I tried to reproduce the example from this site:
Mixture Density Networks
but I don’t know how to update the following lines of code:

# lowest-level?
data = [(y, x)]

for epoch = 1:n_epochs
    
    # forward
    pi_out = pi(y)
    sigma_out = sigma(y)
    mu_out = mu(y)
    l = mdn_loss(pi_out, sigma_out, mu_out, x)
    
    # backward
    Tracker.back!(l)
    for p in pars
        Tracker.update!(opt, p, Tracker.grad(p))
    end

    if epoch % 1000 == 0
        println("Epoch: ", epoch, " loss: ", l)
    end
end

In order to achieve the final result:
[image: Flux_ example_to_update]

I am working with:

(@v1.7) pkg> st Flux
      Status `C:\Users\Hermesr\.julia\environments\v1.7\Project.toml`
  [587475ba] Flux v0.13.0

Hi @HerAdri

I guess you can change the backward pass
from →

    l = mdn_loss(pi_out, sigma_out, mu_out, x)
    
    # backward
    Tracker.back!(l)
    for p in pars
        Tracker.update!(opt, p, Tracker.grad(p))
    end

to →

    # backward
    gs = gradient(pars) do
        l = mdn_loss(pi_out, sigma_out, mu_out, x)
    end
    update!(opt, pars, gs)

This should work in Flux v0.13.0, though I haven’t checked it myself…

Following your proposal, the lines of code now look like this (is this right?):

for epoch = 1:n_epochs
    
    # forward
    pi_out = pi(y)
    sigma_out = sigma(y)
    mu_out = mu(y)
    
    
    # backward
    gs = gradient(pars) do
        l = mdn_loss(pi_out, sigma_out, mu_out, x)
    end
    update!(opt, pars, gs)

    if epoch % 1000 == 0
        println("Epoch: ", epoch, " loss: ", l)
    end
end

and we get the following error:

ERROR: UndefVarError: l not defined
Stacktrace:
 [1] top-level scope
   @ c:\projects\Julia Flux\Mixture Density Networks with Julia.jl:162

Ooh, okay.
Maybe we don’t need to assign the loss value to the variable l. Can you try removing l as follows →

    # backward
    gs = gradient(pars) do
        mdn_loss(pi_out, sigma_out, mu_out, x)
    end
    update!(opt, pars, gs)

It didn’t work; I get the same error.

Sorry, the problem is not with this line →

    l = mdn_loss(pi_out, sigma_out, mu_out, x)

The do block is an anonymous function, so an l assigned inside it is local to that function. Once the do block finishes executing, that variable is gone, so l is undefined at the line →

println("Epoch: ", epoch, " loss: ", l)

You can use the Flux.withgradient function here to get both the loss value and the gradients.
The code below fits our needs →

    # backward
    l, gs = Flux.withgradient(pars) do
        mdn_loss(pi_out, sigma_out, mu_out, x)
    end
    update!(opt, pars, gs)
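For reference, here is a minimal sketch of what `Flux.withgradient` returns, assuming the Flux 0.13 implicit-parameter API used in this thread. The array `W` and the sum-of-squares loss are made up for illustration. The call evaluates the do block once and returns its value together with the gradients, which is why `l, gs = ...` unpacks into the loss and a `Grads` object.

```julia
using Flux

W = Float32[1 2; 3 4]        # hypothetical parameter array
ps = Flux.params(W)

# withgradient returns (val = loss value, grad = gradients w.r.t. ps)
l, gs = Flux.withgradient(ps) do
    sum(abs2, W)             # toy loss: sum of squared entries
end

println("loss: ", l)         # the value of the do block
println("grad: ", gs[W])     # gradient of sum(W .^ 2), i.e. 2W
```

Because the loss value comes back from the call itself, nothing needs to be assigned inside the do block.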

Alternatively, defining l = 0f0 before the do block also works with our initial solution (the gradient(...) function).
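The scoping rule behind both the `UndefVarError` and the `l = 0f0` workaround can be reproduced without Flux. In this minimal sketch, `call_closure` is a hypothetical stand-in for `gradient` that simply calls the do-block function. The examples are wrapped in functions to sidestep Julia's top-level (soft scope) rules, which behave slightly differently in the REPL and in files.

```julia
# Hypothetical stand-in for `gradient`: just calls the do-block function.
call_closure(f) = f()

function without_predefinition()
    call_closure() do
        l = 1.5              # creates a NEW variable local to the closure
    end
    return @isdefined(l)     # false: no `l` exists in this function's scope
end

function with_predefinition()
    l = 0.0                  # `l` now exists in the enclosing scope...
    call_closure() do
        l = 1.5              # ...so this assigns that captured variable
    end
    return l                 # 1.5
end

println(without_predefinition())
println(with_predefinition())
```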

There are no more errors inside the loop,
but the result is not as expected:
[image: Flux_ eerror]

Could you illustrate this for me:
“And also defining l = 0f0 outside the do block works fine with our initial solution (function gradient(…))”

Are the loss values you are getting comparable to those shown in the blog?
Unfortunately, I don’t know anything about Mixture Density Networks.


For example, like below →

for epoch = 1:n_epochs
    
    # forward
    pi_out = pi(y)
    sigma_out = sigma(y)
    mu_out = mu(y)
    
    # backward
    l = 0f0   # define l before entering into do block to access later
    gs = gradient(pars) do
        l = mdn_loss(pi_out, sigma_out, mu_out, x)
    end
    update!(opt, pars, gs)

    if epoch % 1000 == 0
        println("Epoch: ", epoch, " loss: ", l)  # we can access l here without any undefined error
    end
end
ERROR: UndefVarError: gradient not defined
Stacktrace:
 [1] top-level scope
   @ c:\projects\Julia Flux\Mixture Density Networks with Julia.jl:159

Epoch: 1000 loss: 0.0
Epoch: 2000 loss: 0.0
Epoch: 3000 loss: 0.0
Epoch: 4000 loss: 0.0
Epoch: 5000 loss: 0.0
Epoch: 6000 loss: 0.0
Epoch: 7000 loss: 0.0
Epoch: 8000 loss: 0.0

Is this the new error? :face_with_monocle:
Can you also paste the code, so that we can see exactly what is at line 159?

It would be better to summarize our two solutions so that it is less confusing…

  1. Define l outside the do block and use the gradient / Flux.gradient function →
for epoch = 1:n_epochs
    
    # forward
    pi_out = pi(y)
    sigma_out = sigma(y)
    mu_out = mu(y)
    
    # backward
    l = 0f0   # define l before entering into do block to access later
    gs = gradient(pars) do   # or gs = Flux.gradient(pars) do
        l = mdn_loss(pi_out, sigma_out, mu_out, x)
    end
    update!(opt, pars, gs)

    if epoch % 1000 == 0
        println("Epoch: ", epoch, " loss: ", l)  # we can access l here without any undefined error
    end
end
  2. Use the Flux.withgradient function →
for epoch = 1:n_epochs
    
    # forward
    pi_out = pi(y)
    sigma_out = sigma(y)
    mu_out = mu(y)
    
    # backward
    l, gs = Flux.withgradient(pars) do
        mdn_loss(pi_out, sigma_out, mu_out, x)
    end
    update!(opt, pars, gs)

    if epoch % 1000 == 0
        println("Epoch: ", epoch, " loss: ", l) 
    end
end

This is the code:

#https://www.janisklaise.com/post/mdn_julia/
using Distributions
using Flux
using Plots
using Random
Random.seed!(12345); # for reproducibility

function generate_data(n_samples)
    ϵ = rand(Normal(), 1, n_samples)
    x = rand(Uniform(-10.5, 10.5), 1, n_samples)
    y = 7sin.(0.75x) + 0.5x + ϵ
    return x, y
end

n_samples = 1000
x, y = generate_data(n_samples); # semicolon to suppress output as in MATLAB
scatter(transpose(y), transpose(x), alpha=0.2)

n_gaussians = 5
n_hidden = 20;

z_h = Dense(1, n_hidden, tanh)
z_π = Dense(n_hidden, n_gaussians)
z_σ = Dense(n_hidden, n_gaussians, exp)
z_μ = Dense(n_hidden, n_gaussians);

pi = Chain(z_h, z_π, softmax)
sigma = Chain(z_h, z_σ)
mu = Chain(z_h, z_μ);

#We need to implement this loss function ourselves:

function gaussian_distribution(y, μ, σ)
    # periods are used for element-wise operations
    result = 1 ./ ((sqrt(2π).*σ)).*exp.(-0.5((y .- μ)./σ).^2)
end;
function mdn_loss(π, σ, μ, y)
    result = π.*gaussian_distribution(y, μ, σ)
    result = sum(result, dims=1)
    result = -log.(result)
    return mean(result)
end;

pars = Flux.params(pi, sigma, mu)
opt = ADAM()
n_epochs = 8000;

#Finally we write the training loop.
# lowest-level?
data = [(y, x)]

#1 Defining l outside of do block and using gradient / Flux.gradient function 
for epoch = 1:n_epochs
    
    # forward
    pi_out = pi(y)
    sigma_out = sigma(y)
    mu_out = mu(y)
    
    
    # backward
    l = 0f0   # define l before entering into do block to access later
    gs = Flux.gradient(pars) do
        l = mdn_loss(pi_out, sigma_out, mu_out, x)
    end
    Flux.update!(opt, pars, gs)

    if epoch % 1000 == 0
        println("Epoch: ", epoch, " loss: ", l)  # we can access l here without any undefined error
    end
end

x_test = range(-15, stop=15, length=n_samples)
pi_data = pi(transpose(collect(x_test)))
sigma_data = sigma(transpose(collect(x_test)))
mu_data = mu(transpose(collect(x_test)));
plot(collect(x_test), transpose(pi_data), lw=2)

plot(collect(x_test), transpose(sigma_data), lw=2)

plot(collect(x_test), transpose(mu_data), lw=2)

#We can also plot the mean μₖ(x) for each Gaussian together with the range  μₖ(x) ± σₖ(x):
plot(collect(x_test), transpose(mu_data), lw=2, ribbon=transpose(sigma_data),
    ylim=(-12,12), fillalpha=0.3)
scatter!(transpose(y), transpose(x), alpha=0.05)

function gumbel_sample(x)
    z = rand(Gumbel(), size(x))
    return argmax(log.(x) + z, dims=1)
end

k = gumbel_sample(pi_data);

sampled = rand(Normal(),1, 1000).*sigma_data[k] + mu_data[k];

scatter(transpose(y), transpose(x), alpha=0.2)
scatter!(collect(x_test), transpose(sampled), alpha=0.5)

With it, the loss does not change:

Epoch: 1000 loss: 8.61979609988576
Epoch: 2000 loss: 8.61979609988576
Epoch: 3000 loss: 8.61979609988576
Epoch: 4000 loss: 8.61979609988576
Epoch: 5000 loss: 8.61979609988576
Epoch: 6000 loss: 8.61979609988576
Epoch: 7000 loss: 8.61979609988576
Epoch: 8000 loss: 8.61979609988576
With the version using the Flux.withgradient function →
for epoch = 1:n_epochs
    
     # forward
     pi_out = pi(y)
     sigma_out = sigma(y)
     mu_out = mu(y)
    
     # backward
     l, gs = Flux.withgradient(pars) do
         mdn_loss(pi_out, sigma_out, mu_out, x)
     end
     update!(opt, pars, gs)

     if epoch % 1000 == 0
         println("Epoch: ", epoch, " loss: ", l)
     end
end

The same thing happens here: the loss does not change.

Epoch: 1000 loss: 4.8080576028606465
Epoch: 2000 loss: 4.8080576028606465
Epoch: 3000 loss: 4.8080576028606465
Epoch: 4000 loss: 4.8080576028606465
Epoch: 5000 loss: 4.8080576028606465
Epoch: 6000 loss: 4.8080576028606465
Epoch: 7000 loss: 4.8080576028606465
Epoch: 8000 loss: 4.8080576028606465
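A likely cause of the constant loss: `pi(y)`, `sigma(y)` and `mu(y)` are evaluated before `gradient` is called, so inside the do block Zygote only sees plain arrays that do not depend on `pars`. With implicit parameters, every entry of `gs` is then `nothing`, and in Flux 0.13 `update!(opt, pars, gs)` skips parameters whose gradient is `nothing`, so nothing is ever updated. This minimal sketch uses a single hypothetical `Dense` layer and made-up data to show the difference:

```julia
using Flux

layer = Dense(1, 1)              # hypothetical one-layer model
ps = Flux.params(layer)
xin = rand(Float32, 1, 8)        # made-up inputs
target = rand(Float32, 1, 8)     # made-up targets

# Forward pass OUTSIDE the closure: the loss only sees a constant array.
out = layer(xin)
gs_outside = Flux.gradient(ps) do
    Flux.Losses.mse(out, target)
end

# Forward pass INSIDE the closure: Zygote can trace the loss back to `ps`.
gs_inside = Flux.gradient(ps) do
    Flux.Losses.mse(layer(xin), target)
end

println(gs_outside[layer.weight] === nothing)   # no gradient, so no update
println(gs_inside[layer.weight])                # an actual gradient array
```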

It’s weird :innocent:
I will play with the code and let you know if I have any success!


Sorry, I couldn’t figure out what the issue is; the loss stays constant (without any improvement).
Maybe you could open an issue on GitHub at jklaise / personal_website so that the author of that blog can help correct the code,
or I hope someone here who knows Flux better will come up with a solution!

Hey @HerAdri

I was finally able to solve it, thanks to @ToucheSir :grinning:, who explained in another post how gradients are calculated in Zygote here.
Taking that comment into account, I moved every calculation that needs to be tracked for the backward pass (the forward pass through pi, sigma and mu) into the gradient do block.
I changed the code as follows, and it seems to do the job for me :wink:; please confirm whether it works for you →

# lowest-level?
data = [(y, x)]

for epoch in 1:n_epochs
    
    # define l here so it is still accessible after the do block
    l = 0f0
    
    # forward and backward pass, both inside the do block
    gs = gradient(pars) do
        pi_out = model[:pi](y)
        sigma_out = model[:sigma](y)
        mu_out = model[:mu](y)
        l = mdn_loss(pi_out, sigma_out, mu_out, x)
    end
    Flux.update!(opt, pars, gs)

    if epoch % 1000 == 0
        println("Epoch: ", epoch, " loss: ", l)
    end
end
ERROR: UndefVarError: model not defined
Stacktrace:
 [1] _pullback(::Zygote.Context, ::var"#9#10")
   @ Zygote C:\Users\Hermesr\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:9
 [2] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
   @ Zygote C:\Users\Hermesr\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:352
 [3] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
   @ Zygote C:\Users\Hermesr\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:75
 [4] top-level scope
   @ c:\projects\Julia Flux\Mixture Density Networks.jl:59
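The `UndefVarError: model not defined` comes from the `model[:pi]`-style lookups: the script above never defines a `model` dictionary, it defines separate `pi`, `sigma` and `mu` chains. Below is a self-contained sketch of the same training loop calling those chains directly inside the do block, written against the Flux 0.13 implicit-parameter API used in this thread. Assumptions on my part: the chains are renamed `pi_net`, `sigma_net` and `mu_net` (hypothetical names) to avoid shadowing Base's `pi`; the data is generated with `rand`/`randn` instead of Distributions.jl; and the loop is wrapped in a hypothetical `train!` function so that `l` is a plain local variable regardless of REPL/file scoping rules.

```julia
using Flux
using Random
using Statistics: mean

Random.seed!(12345)

# Data as in the blog post: x ~ U(-10.5, 10.5), y = 7 sin(0.75x) + 0.5x + noise
n_samples = 1000
x = rand(1, n_samples) .* 21 .- 10.5
y = 7 .* sin.(0.75 .* x) .+ 0.5 .* x .+ randn(1, n_samples)

n_gaussians, n_hidden = 5, 20
z_h = Dense(1, n_hidden, tanh)                              # shared hidden layer
pi_net = Chain(z_h, Dense(n_hidden, n_gaussians), softmax)  # mixture weights
sigma_net = Chain(z_h, Dense(n_hidden, n_gaussians, exp))   # positive std devs
mu_net = Chain(z_h, Dense(n_hidden, n_gaussians))           # component means

# Element-wise Gaussian density
gaussian(t, μ, σ) = @. exp(-0.5 * ((t - μ) / σ)^2) / (sqrt(2π) * σ)

# Mean negative log-likelihood of the mixture (same formula as the script)
function mdn_loss(πk, σ, μ, t)
    likelihood = sum(πk .* gaussian(t, μ, σ); dims=1)
    return mean(-log.(likelihood))
end

function train!(pars, opt, y, x; n_epochs=1000)
    losses = Float64[]
    for epoch in 1:n_epochs
        l = 0.0                      # defined here so it outlives the do block
        gs = Flux.gradient(pars) do
            # forward pass INSIDE the closure, so Zygote can trace it to `pars`
            l = mdn_loss(pi_net(y), sigma_net(y), mu_net(y), x)
        end
        Flux.update!(opt, pars, gs)
        push!(losses, l)
        epoch % 250 == 0 && println("Epoch: ", epoch, " loss: ", l)
    end
    return losses
end

pars = Flux.params(pi_net, sigma_net, mu_net)
losses = train!(pars, ADAM(), y, x)
```

With the forward pass inside the closure, the printed loss should now change from epoch to epoch instead of staying at its initial value.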