It’s very easy to convert a Flux model to Float64 using code like so:
# Build the four-layer MLP and convert it to Float64 with Flux's `f64`.
# `N` (the input width) is expected to be defined by the surrounding code.
model = f64(
    Chain(
        Dense(N, 32, tanh),
        Dense(32, 16, tanh),
        Dense(16, 8, tanh),
        Dense(8, 1, softplus),
    ),
)
return model
In Lux, I convert the parameters to Float64 after `Lux.setup`, as shown in the script below, to get a Float64 equivalent.
When I run the code below, the loss always gets stuck close to 1e-8, leading me to believe that there is precision loss to Float32 at some point.
Can anyone help me figure it out?
using Lux, Random, Optimisers, Zygote, ComponentArrays, Statistics, Printf
# Seed the default RNG so parameter initialisation and the dummy data are reproducible.
rng = Random.default_rng()
Random.seed!(rng, 0)

# Three Dense layers: 2 -> 32 -> 32 -> 1, with tanh on the first two.
model = Chain(Dense(2 => 32, tanh), Dense(32 => 32, tanh), Dense(32 => 1))

# Run everything on the CPU.
device = cpu_device()

# Parameters and states. `Lux.f64` is Lux's precision helper: it converts every
# floating-point array in BOTH the parameters and the states to Float64, rather
# than casting only `ps` element by element.
ps, st = Lux.f64(Lux.setup(rng, model)) .|> device

# Flatten the parameter NamedTuple into a single ComponentArray for the optimiser.
ps = ComponentArray(ps)

# Dummy data: `x` is 2×128 (features × samples); the target is x₁ / (x₁ + x₂),
# stored as a 1×128 row so it lines up with the model output.
x = rand(rng, Float64, 128, 2)' |> device
y = (x[1, :] ./ (x[1, :] .+ x[2, :]))'
"""
    loss_function(model, ps, st, x, y)

Evaluate `model` on `x` with parameters `ps` and states `st`, and return the
mean squared error against `y` together with the states returned by `Lux.apply`.

# Returns
- `(mse, new_state)`: the mean of the squared element-wise residuals and the
  updated layer states.
"""
function loss_function(model, ps, st, x, y)
    prediction, new_state = Lux.apply(model, x, ps, st)
    # Broadcasted subtraction; assumes `prediction` and `y` have compatible shapes.
    residual = prediction .- y
    return mean(abs2, residual), new_state
end
"""
    train(model, ps, st, x, y; epochs=100_000, report_every=10_000, learning_rate=0.0001)

Minimise `loss_function` with Adam and print the loss every `report_every` epochs.

The loop lives in a function on purpose. When this file is run as a script, a
top-level `for` loop that assigns to the existing globals `ps` and `st_opt`
treats them as new locals, so the first read of `ps` raises `UndefVarError`.
Inside a function they are ordinary locals, which also avoids untyped globals
in the hot loop.

# Arguments
- `model`, `ps`, `st`, `x`, `y`: forwarded unchanged to `loss_function`.

# Keywords
- `epochs`: number of optimisation steps.
- `report_every`: progress is printed whenever `epoch % report_every == 0`.
- `learning_rate`: step size passed to `Optimisers.Adam`.

# Returns
- `(ps, st_opt)`: the trained parameters and the final optimiser state.

# Throws
- `ArgumentError` if `report_every` is not positive.
"""
function train(model, ps, st, x, y;
               epochs=100_000, report_every=10_000, learning_rate=0.0001)
    report_every > 0 ||
        throw(ArgumentError("report_every must be positive, got $report_every"))

    st_opt = Optimisers.setup(Optimisers.Adam(learning_rate), ps)
    stt = time()
    for epoch in 1:epochs
        # Gradient of the scalar loss with respect to the parameters only.
        gs = gradient(p -> loss_function(model, p, st, x, y)[1], ps)[1]
        # Optimisers.update returns the new optimiser state and the new parameters.
        st_opt, ps = Optimisers.update(st_opt, ps, gs)
        if epoch % report_every == 0
            # Report the loss at the freshly updated parameters.
            lss = loss_function(model, ps, st, x, y)[1]
            @printf("Training Progress Epoch : %.4d Loss : %.2e Elapsed time: %.2f mins\n",
                    epoch, lss, (time() - stt) / 60.0)
        end
    end
    return ps, st_opt
end

# Rebind the globals so the trained parameters and optimiser state stay available.
ps, st_opt = train(model, ps, st, x, y)