Initialize Lux.jl NN parameters according to Lux.glorot_normal

Hi all,

I would like to initialize the parameters of a Lux.jl neural network with something other than the default random initialization, for instance with Lux.glorot_normal. I am using this network with NeuralPDE.jl with the following architecture:

using Lux, NeuralPDE, Random, ComponentArrays

dim = length(domains)
activation = Lux.σ
nnodes = 10
chain = Lux.Chain(Lux.Dense(dim, nnodes, activation), 
            Lux.Dense(nnodes, 1)) 

ps = Lux.setup(Random.default_rng(), chain)[1] |> ComponentArray .|> Float32

discretization = PhysicsInformedNN(chain, strategy, init_params = ps)
@time prob = discretize(pde_sys, discretization)

I’ve been looking through the Lux and NeuralPDE documentation and can’t find anything that addresses this. It seems like passing an argument to Lux.setup would solve the problem, but I’m not sure what that argument would be.

Maybe this is not an essential detail, but my network’s performance seems to be highly dependent on the initial parameter choice.

Any help is very much appreciated!

You need to specify the initialization for each layer; see Layers - Lux.jl. Passing it via setup is not possible, since what is being initialized is highly layer dependent.
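
For the architecture above, a minimal sketch of that approach, assuming the init_weight (and optionally init_bias) keyword arguments of Lux.Dense documented on that page; dim = 2 here is just a placeholder for length(domains):

using Lux, Random

dim = 2                    # placeholder for length(domains)
activation = Lux.σ
nnodes = 10

# Each layer stores its own initializer, so pass it per layer:
chain = Lux.Chain(Lux.Dense(dim, nnodes, activation; init_weight = Lux.glorot_normal),
                  Lux.Dense(nnodes, 1; init_weight = Lux.glorot_normal))

# setup then draws the parameters using the initializer stored in each layer
ps, st = Lux.setup(Random.default_rng(), chain)

The resulting ps can then be converted to a ComponentArray and passed to PhysicsInformedNN via init_params exactly as in your snippet.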