Float64 Lux model (Loss stuck at 1e-8)

It’s very easy to convert a Flux model to Float64 using code like so:

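    using Flux

    # `build_model` is an illustrative wrapper; N is the number of input features
    function build_model(N)
        model = Chain(Dense(N,  32, tanh),
                      Dense(32, 16, tanh),
                      Dense(16, 8, tanh),
                      Dense(8, 1, softplus)) |> f64   # f64 converts the parameters to Float64
        return model
    end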
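Unlike Flux, Lux keeps the parameters outside the model, in the nested NamedTuple returned by `Lux.setup`, so a Float64 conversion has to be applied to `ps` rather than to the `Chain`. The question below does this through ComponentArrays. As a rough alternative sketch, a hypothetical helper `to_f64` (not a Lux function) could walk a plain nested NamedTuple and widen every floating-point array. The `ps32` tree here is a hand-built stand-in shaped like Lux parameters, not real `Lux.setup` output.

```julia
# Hypothetical helper (not part of Lux): recursively convert every
# floating-point array in a nested NamedTuple to Float64.
to_f64(x::AbstractArray{<:AbstractFloat}) = Float64.(x)  # weights and biases
to_f64(x::NamedTuple) = map(to_f64, x)                    # recurse into each layer
to_f64(x) = x                                             # leave anything else unchanged

# Stand-in for the parameter tree of a small two-layer Chain
ps32 = (layer_1 = (weight = rand(Float32, 32, 2), bias = zeros(Float32, 32)),
        layer_2 = (weight = rand(Float32, 1, 32), bias = zeros(Float32, 1)))

ps64 = to_f64(ps32)
println(eltype(ps64.layer_1.weight))   # Float64
```

Widening the Float32 initial values loses nothing; what matters is that every later update is computed and stored in Float64.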

In Lux, I am doing the following to get a Float64 equivalent.
When I run the code below, the loss always gets stuck around 1e-8, which makes me suspect that something is being converted to Float32 at some point.
Can anyone help me figure it out?

    using Lux, Random, Optimisers, Zygote, ComponentArrays, Statistics, Printf

    # Seeding
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    # Construct the layer
    model = Chain(Dense(2 => 32, tanh), Dense(32 => 32, tanh), Dense(32 => 1))

    # Get the device determined by Lux
    device = cpu_device()

    # Parameter and State Variables
    ps, st = Lux.setup(rng, model) .|> device
    ps = ps |> ComponentArray .|> Float64   # convert all parameters to Float64

    # Dummy Input
    x = rand(rng, Float64, 128, 2)' |> device

    y = (x[1, :] ./ (x[1, :] .+ x[2, :]))'

    function loss_function(model, ps, st, x, y)
        y_pred, st = Lux.apply(model, x, ps, st)
        mse_loss = mean(abs2, y_pred .- y)
        return mse_loss, st
    end

    st_opt = Optimisers.setup(Optimisers.Adam(0.0001), ps)
    stt = time()
    for epoch in 1:100000
        # Needed when run as a script (not in the REPL), otherwise these become loop-local
        global st_opt, ps
        # Gradients
        gs = gradient(p -> loss_function(model, p, st, x, y)[1], ps)[1]
        # Optimization
        st_opt, ps = Optimisers.update(st_opt, ps, gs)
        if epoch % 10000 == 0
            lss = loss_function(model, ps, st, x, y)[1]
            @printf("Training Progress  Epoch : %.4d    Loss : %.2e    Elapsed time: %.2f mins\n",
                    epoch, lss, (time() - stt) / 60.0)
        end
    end
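One way to test the Float32 hypothesis without touching the model is to measure how large an MSE rounding to Float32 could cause on its own for targets like these. The sketch below is plain Julia and uses a stand-in vector of values in [0, 1) (the range of `y = x1 / (x1 + x2)`) rather than the real model output. It only bounds the rounding of the outputs themselves, not every effect of training in Float32, but the resulting MSE comes out well under 1e-14, orders of magnitude below a 1e-8 plateau, which suggests the plateau is more likely an optimization issue than a precision floor.

```julia
using Random, Statistics

# Stand-in targets in [0, 1), like y = x1 / (x1 + x2) in the question
rng = MersenneTwister(0)
y64 = rand(rng, 128)

# Round-trip through Float32: the error a silent Float32 conversion
# of the predictions would introduce by itself
y32 = Float64.(Float32.(y64))
rounding_mse = mean(abs2, y32 .- y64)

println("eps(Float32) = ", eps(Float32))   # ≈ 1.19e-7
println("eps(Float64) = ", eps(Float64))   # ≈ 2.22e-16
println("MSE from Float32 rounding alone = ", rounding_mse)
```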

Solved.
All I had to do was add a more adaptive learning-rate adjustment and train for much longer on a very small dataset, deliberately overfitting it.
Now the loss goes all the way down to machine zero.

Something along these lines.

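    # Before the training loop (initial values assumed; the post does not give them):
    lr_      = 1e-4   # matches Adam(0.0001) in the question
    min_loss = Inf    # loss at the previous check

    # Inside the `if epoch % ... == 0` block of the training loop:
    lss = loss_function(model, ps, st, x, y)[1]
    if min_loss > lss
        # Loss went down since the last check: shrink the step by 10%
        lr_ = lr_ * 0.9
    else
        # Loss did not go down: grow the step by 1%, capped at 1e-3
        lr_ = min(lr_ * 1.01, 1e-3)
    end
    Optimisers.adjust!(st_opt, lr_)
    # Despite its name, `min_loss` holds the previous check's loss, not a running minimum
    min_loss = lss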
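To check the schedule in isolation before wiring it into training, the learning-rate rule from the snippet above can be written as a pure function. This is a sketch: `next_lr` and `lr_trajectory` are hypothetical names, the 0.9 / 1.01 / 1e-3 constants come from the post, and the loss sequence is made up purely for illustration.

```julia
# Hypothetical helper: the post's learning-rate rule as a pure function.
#   loss went down since the previous check -> shrink lr by 10%
#   otherwise                               -> grow lr by 1%, capped at lr_max
function next_lr(lr, prev_loss, loss; shrink = 0.9, grow = 1.01, lr_max = 1e-3)
    return prev_loss > loss ? lr * shrink : min(lr * grow, lr_max)
end

# Apply the rule to a sequence of losses observed at successive checks
function lr_trajectory(losses; lr0 = 1e-4)
    lr, prev = lr0, Inf
    lrs = Float64[]
    for loss in losses
        lr = next_lr(lr, prev, loss)
        push!(lrs, lr)
        prev = loss
    end
    return lrs
end

# Made-up losses (illustrative only, not results from the post)
losses = [1e-3, 5e-4, 6e-4, 2e-4, 2e-4, 1e-4]
println(lr_trajectory(losses))
```

Note that this rule shrinks the step when the loss improves and grows it when the loss stalls, the reverse of a typical reduce-on-plateau schedule; it is reproduced here as written in the post.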