I’m switching to Lux for my ML problems, and I need advice on a couple of basic things.
- How can I build my network flexibly, so that the number of layers and the number of nodes per layer are arguments?
- How can I force the network's parameters to be Float64?
Here’s my example code, with hardcoded options for the number of layers. You’ll see that my attempt to make the parameters Float64 has failed:
```julia
using Lux
using Random

function build_model(n_in, n_out, n_layers, n_nodes, act_fun=leakyrelu, last_fun=relu)
    first_layer = Lux.Dense(n_in, n_nodes, act_fun)
    last_layer = Lux.Dense(n_nodes => n_out, last_fun)
    # The number of hidden layers is hardcoded: only 4 or 8 are supported
    if n_layers == 4
        m = Chain(first_layer, Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun),
            Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun), last_layer)
    elseif n_layers == 8
        m = Chain(first_layer, Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun),
            Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun),
            Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun),
            Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun), last_layer)
    else
        error("n_layers must be 4 or 8")
    end
    # Attempt to make the parameters Float64
    return Lux.f64(m)
end

nlayers = 4
nnodes = 16
m = build_model(4, 3, nlayers, nnodes)
ps, st = Lux.setup(Random.default_rng(), m)
# Print the type of the first layer's weight matrix
println("Parameter data type: " * string(typeof(ps[1][1])))
```
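For comparison, here is a minimal sketch of one possible approach to both questions, assuming a recent Lux version in which `Lux.f64` converts floating-point arrays inside nested containers. `build_model_flexible` is a hypothetical name, and `n_layers` keeps the meaning it has above (the number of hidden `n_nodes => n_nodes` layers). The hidden layers come from a comprehension and are splatted into `Chain`. As I understand Lux's design, a layer such as `Dense` stores only sizes, the activation, and initializer functions, not arrays; the arrays are created by `Lux.setup`. So `f64` is applied to `ps` (and `st`) rather than to the model.

```julia
using Lux
using Random

# Hypothetical helper: depth is an argument instead of a hardcoded branch.
# n_layers = number of hidden (n_nodes => n_nodes) layers, as in build_model.
function build_model_flexible(n_in, n_out, n_layers, n_nodes, act_fun=leakyrelu, last_fun=relu)
    first_layer = Dense(n_in => n_nodes, act_fun)
    # Build the hidden layers in a comprehension; works for any n_layers >= 0
    hidden = [Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers]
    last_layer = Dense(n_nodes => n_out, last_fun)
    # Splat the vector of hidden layers into Chain's varargs constructor
    return Chain(first_layer, hidden..., last_layer)
end

rng = Random.default_rng()
m = build_model_flexible(4, 3, 4, 16)
ps, st = Lux.setup(rng, m)

# The parameter arrays live in `ps`, so convert those (and the states) to Float64
ps = Lux.f64(ps)
st = Lux.f64(st)

println("Weight eltype: ", eltype(ps.layer_1.weight))

# Forward pass with Float64 input: 4 features, batch of 10
x = rand(rng, Float64, 4, 10)
y, _ = m(x, ps, st)
```

An alternative would be to pass Float64 initializers to each `Dense` layer, but the relevant keyword names have changed between Lux versions, so converting `ps` after `setup` is the less version-sensitive option.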