How to implement LayerNorm in NeuralPDE with Lux.jl

I tried to use LayerNorm from the Lux.jl package in NeuralPDE with the following code:

chain =[Lux.Chain(Lux.LayerNorm(1,Lux.relu),Dense(1,10,Lux.tanh),Dense(10,20,Lux.tanh),Dense(20,10,Lux.tanh),Dense(10,1)) for _ in 1:12]

I then got this error:

MethodError: no method matching LayerNorm(::Int64, ::typeof(NNlib.relu))

The syntax of LayerNorm is:

LayerNorm(shape::NTuple{N, Int}, activation=identity; epsilon=1f-5, dims=Colon(),
          affine::Bool=true, init_bias=zeros32, init_scale=ones32,)
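
So the first positional argument is an NTuple of Ints rather than a bare Int. As a stand-alone check the tuple form constructs fine; a minimal sketch (the feature count 4 here is just illustrative):

using Lux
ln = Lux.LayerNorm((4,), Lux.relu)  # shape passed as a tuple, not the integer 4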

But when I tried the following instead:

chain =[Lux.Chain(Lux.LayerNorm((1,1),Lux.relu),Dense(1,10,Lux.tanh),Dense(10,20,Lux.tanh),Dense(20,10,Lux.tanh),Dense(10,1)) for _ in 1:12]

this time the message was:

MethodError: no method matching layernorm(::Matrix{Float64}, ::Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, ::Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}; dims::Colon, epsilon::Float32)

Please explain to me the right way to implement it.

It should be a 1-tuple. The shape argument declares the per-sample dimensions, and NeuralPDE feeds the chain 2-d (features, batch) matrices; with shape (1,1) the layer expects a 3-d input array, which is exactly what the MethodError above is complaining about. With shape (1,) it matches the matrix input:

chain = [Lux.Chain(Lux.LayerNorm((1,), Lux.relu), Dense(1, 10, Lux.tanh), Dense(10, 20, Lux.tanh), Dense(20, 10, Lux.tanh), Dense(10, 1)) for _ in 1:12]
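
As a quick sanity check outside NeuralPDE, the corrected chain accepts a (1, batch) matrix; a minimal sketch (the batch size 16 is just illustrative):

using Lux, Random
chain = Lux.Chain(Lux.LayerNorm((1,), Lux.relu), Dense(1, 10, Lux.tanh),
                  Dense(10, 20, Lux.tanh), Dense(20, 10, Lux.tanh), Dense(10, 1))
ps, st = Lux.setup(Random.default_rng(), chain)  # initialize parameters and state
y, _ = chain(rand(Float32, 1, 16), ps, st)       # 16 sample points, 1 coordinate each
@assert size(y) == (1, 16)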

Thank you for your reply.