Float64 Lux model (Loss stuck at 1e-8)

It’s very easy to convert a Flux model to Float64 using code like so:

```julia
model = Chain(Dense(N,  32, tanh),
              Dense(32, 16, tanh),
              Dense(16, 8,  tanh),
              Dense(8,  1,  softplus)) |> f64
return model
```
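As a quick sanity check (a sketch; the layer sizes here are placeholders, not the ones from my model), you can confirm the conversion actually happened by inspecting a weight matrix:

```julia
using Flux

# Hypothetical sizes, just for illustration
m = Chain(Dense(4 => 32, tanh), Dense(32 => 1, softplus)) |> f64

eltype(m[1].weight)  # Float64 after |> f64
```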

In Lux, I am doing the following to get a Float64 equivalent.
When I run the code below, the loss always plateaus around 1e-8, which makes me suspect that something is falling back to Float32 at some point.
Can anyone help me figure out where?

```julia
using Lux, Random, Optimisers, Zygote, ComponentArrays, Statistics, Printf

# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)

# Construct the model
model = Chain(Dense(2 => 32, tanh), Dense(32 => 32, tanh), Dense(32 => 1))

# Get the device determined by Lux
device = cpu_device()

# Parameter and state variables, promoted to Float64
ps, st = Lux.setup(rng, model) .|> device
ps = ComponentArray(ps) .|> Float64

# Optimiser state (Adam assumed; st_opt is used in the loop below)
st_opt = Optimisers.setup(Optimisers.Adam(1e-3), ps)

# Dummy input (2 features × 128 samples) and target
x = rand(rng, Float64, 128, 2)' |> device
y = (x[1, :] ./ (x[1, :] .+ x[2, :]))'

function loss_function(model, ps, st, x, y)
    y_pred, st = Lux.apply(model, x, ps, st)
    mse_loss = mean(abs2, y_pred .- y)
    return mse_loss, st
end

stt = time()
for epoch in 1:100_000
    gs = gradient(p -> loss_function(model, p, st, x, y)[1], ps)[1]
    # Optimisation step
    st_opt, ps = Optimisers.update(st_opt, ps, gs)
    if epoch % 10_000 == 0
        lss = loss_function(model, ps, st, x, y)[1]
        @printf("Training Progress  Epoch : %.4d    Loss : %.2e    Elapsed time: %.2f mins\n",
                epoch, lss, (time() - stt) / 60)
    end
end
```
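For what it's worth, one way to hunt for a silent Float32 fallback is to check the element types of the parameters and of the forward-pass output directly (a sketch using the same model; the batch size of 8 is arbitrary):

```julia
using Lux, Random, ComponentArrays

rng = Random.default_rng()
Random.seed!(rng, 0)

model = Chain(Dense(2 => 32, tanh), Dense(32 => 32, tanh), Dense(32 => 1))
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps) .|> Float64

x = rand(rng, Float64, 2, 8)   # Float64 input, features × batch
y_pred, _ = Lux.apply(model, x, ps, st)

eltype(ps)      # Float64
eltype(y_pred)  # should also be Float64 if nothing demotes precision
```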

Solved.
All I had to do was add a more adaptive learning-rate schedule and train for much longer on a very small dataset to overfit.
Now the loss goes all the way down to machine zero.

Something along these lines (inside the training loop; `Optimisers.adjust!` pushes the new rate into the optimiser state):

```julia
lss = loss_function(model, ps, st, x, y)[1]
if min_loss > lss
    min_loss = lss
    lr_ = lr_ * 0.9                  # shrink the step each time the loss improves
    Optimisers.adjust!(st_opt, lr_)  # apply the new learning rate
end
```