I’ve been trying to get a feel for Flux by modeling a simple non-linear function, f(x) = x^2.
At the end of each epoch, I print the model output for an input of 3 (expecting 9), but it seems to converge to about 34.
Can anyone see what I’m doing wrong?
```julia
using Flux
using Printf

model = Chain(
    Dense(1, 50),
    Dense(50, 1))

x = collect(-10:.1:10)'   # 1×201 row of inputs
y = x.^2                  # targets
N = length(y)

loss(x, y) = Flux.mse(model(x), y)
opt = ADAM()
epochs = 15
ps = Flux.params(model)

@progress for epoch = 1:epochs
    # one gradient step per training sample
    for i = 1:N
        gs = Flux.Tracker.gradient(() -> loss(x[:,i], y[i]), ps)
        Flux.Tracker.update!(opt, Tracker.Params(ps), gs)
    end
    # evaluate the model at x = 3 (expecting 9)
    @printf "Epoch: %d 3^2 = %1.2f\n" epoch model([3.0]).data[1]
end
```
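For reference, the value the model settles on is very close to the mean of the training targets on this grid. A quick check (separate from the training script; it only rebuilds `x` and `y` and uses the `Statistics` standard library):

```julia
using Statistics

# same grid and targets as in the training script
x = collect(-10:0.1:10)
y = x .^ 2

# mean of the targets, i.e. the best constant predictor under MSE
println(round(mean(y), digits = 2))
```

This prints 33.67 (exactly 6767/201), which matches the ~34 the model outputs for an input of 3.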