Weights update in Flux

I translated your description into Flux training code. See if you can follow it; it should clear things up.

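```julia
using Flux  # exports Chain, Dense, relu, ADAM, params, train!

input = [0.8, 0.1, 0.3]

model = Chain(
    x -> reshape(x, 3, 1, 1, :),  # 3-vector -> 3×1×1×1 array
    Dense(3, 3, relu),            # 3 inputs -> 3 outputs, relu activation
    y -> reshape(y, 3, :)         # back to a 3×1 matrix
)

output = model(input)

# Weighted sum of the three outputs (returns a 1-element vector)
fn(output) = transpose(output) * [3; 2; 5]

# Squared distance between that weighted sum and the target 2.5
loss(input) = sum((fn(model(input)) .- 2.5) .^ 2)

opt = ADAM()
p = params(model)  # implicit-style parameters: the Dense layer's weight and bias

old_loss = loss(input)

# One optimisation step over a dataset that holds a single sample
Flux.train!(loss, p, [(input,)], opt)

new_loss = loss(input)

println(old_loss)
println(new_loss) # new loss should usually be lower
```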
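Written out, ignoring the reshapes (they only change array shapes, not values), the code above minimises the following. Here $W$ and $b$ are the Dense layer's weight and bias, $x = (0.8, 0.1, 0.3)^\top$ is the input and $c = (3, 2, 5)^\top$ is the weight vector inside `fn`; `train!` obtains these gradients by automatic differentiation and passes them to ADAM, which adjusts $W$ and $b$.

$$
y = \mathrm{relu}(Wx + b), \qquad L(W, b) = \left(c^\top y - 2.5\right)^2
$$

$$
\frac{\partial L}{\partial W} = 2\,\left(c^\top y - 2.5\right)\left(c \odot \mathbb{1}[Wx + b > 0]\right) x^\top,
\qquad
\frac{\partial L}{\partial b} = 2\,\left(c^\top y - 2.5\right)\left(c \odot \mathbb{1}[Wx + b > 0]\right)
$$

Only rows of $W$ whose pre-activation is positive receive a gradient, so if all three units start out inactive, the weights do not move and the loss stays the same.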
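In newer Flux releases the implicit `params` style is deprecated in favour of explicit gradients. A rough equivalent of the same single training step is sketched below, assuming Flux 0.14 or later (`Flux.setup`, `Adam`, and the `Dense(in => out)` constructor). It drops the reshapes, which do not change the loss value here, and uses `Float32` data to match Flux's default parameter type.

```julia
using Flux  # assumes Flux 0.14+ (explicit-gradient training API)

input = Float32[0.8, 0.1, 0.3]

# Same layer as above; Dense maps a length-3 vector straight to a length-3 vector
model = Chain(Dense(3 => 3, relu))

# Weighted sum of the three outputs, returned as a scalar
fn(output) = sum(output .* Float32[3, 2, 5])

# Explicit style: the model is the first argument of the loss
loss(m, x) = (fn(m(x)) - 2.5f0)^2

# Optimiser state for every trainable array in the model
opt_state = Flux.setup(Adam(), model)

old_loss = loss(model, input)

# One pass over the data; each tuple is splatted into loss(model, ...)
Flux.train!(loss, model, [(input,)], opt_state)

new_loss = loss(model, input)

println(old_loss)
println(new_loss)
```

As in the original, `[(input,)]` means one update on one sample; calling `train!` in a loop repeats the step.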