Lux flexible architecture initialization plus setting parameters to Float64

I’m switching to Lux for my ML problems, and I need advice on a couple of basic things.

  1. How can I initialize my network in a flexible way so that the number of layers and nodes is an argument?
  2. How can I force the network to be Float64?

Here’s my example code, with hardcoded options for number of layers. You’ll see my attempt to make the parameters Float64 has failed:

using Lux
using Random

function build_model(n_in, n_out, n_layers, n_nodes, act_fun=leakyrelu, last_fun=relu)
    first_layer = Lux.Dense(n_in => n_nodes, act_fun)
    last_layer = Lux.Dense(n_nodes => n_out, last_fun)
    if n_layers == 4
        m = Chain(first_layer, Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun),
            Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun), last_layer)
    elseif n_layers == 8
        m = Chain(first_layer, Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun),
            Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun),
            Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun),
            Lux.Dense(n_nodes => n_nodes, act_fun), Lux.Dense(n_nodes => n_nodes, act_fun), last_layer)
    end
    return Lux.f64(m)
end

nlayers = 4
nnodes = 16
m = build_model(4, 3, nlayers, nnodes)
ps, st = Lux.setup(Random.default_rng(), m)
println("Parameter data type: " * string(typeof(ps[1][1])))

I’m not very familiar with Lux, but I would think you need to change the type of the parameters and states to Float64 instead of calling it on the model object. But maybe Lux does some magic there to make it work.

In Lux, the model itself doesn’t contain any weights etc., so you need to use f64 on the parameters and states instead; then everything should work. As for setting up a model with a variable number of layers, it’s just typical Julia: you can splat a comprehension, e.g. Chain([Dense(a[i], b[i]) for i in 1:num_layer]...), or create a custom layer (really different from Flux):
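Putting both points together, here is a minimal sketch of a flexible builder that converts the parameters and states rather than the model. This is an illustrative rewrite of the question's build_model, not the only way to do it, and it assumes the current Lux.jl API where f64 recursively converts the eltype of a parameter/state NamedTuple:

```julia
using Lux
using Random

# Build a Chain with n_layers hidden layers of n_nodes each, by splatting a vector.
function build_model(n_in, n_out, n_layers, n_nodes, act_fun=leakyrelu, last_fun=relu)
    hidden = [Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers]
    return Chain(Dense(n_in => n_nodes, act_fun), hidden..., Dense(n_nodes => n_out, last_fun))
end

m = build_model(4, 3, 4, 16)
ps, st = Lux.setup(Random.default_rng(), m)

# Convert the parameters and states, not the model:
ps = Lux.f64(ps)
st = Lux.f64(st)

println("Parameter eltype: ", eltype(ps.layer_1.weight))
```

Since the layer count is just the length of a vector, any depth works without hardcoded branches.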

function MyLayer(num_layer)
    return @compact(
        D = Tuple(Dense(2^i => 2^(i + 1)) for i in 0:num_layer-1)
    ) do x
        h = copy(x)
        for Di in D
            h = Di(h)
        end
        @return h
    end
end
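For completeness, a quick usage sketch of the layer above. The sizes here are just what the definition implies: MyLayer(3) stacks Dense(1 => 2), Dense(2 => 4), Dense(4 => 8), so the input must have one feature row:

```julia
using Lux
using Random

layer = MyLayer(3)
ps, st = Lux.setup(Random.default_rng(), layer)

x = rand(Float32, 1, 5)   # 1 input feature, batch of 5
y, _ = layer(x, ps, st)   # output has size (8, 5)
```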