Create a layer in Flux

What are the best practices for creating a custom layer in Flux?
I noticed that, for example, creating a layer with internal state is not trivial.
What should I do in general if I want to create a layer?

I'm not sure what “internal state” means here, but if it means what I think it does, this should be fairly straightforward. Have you seen Advanced Model Building · Flux?

For example, a recurrent layer has a state. How should I build a recurrent layer?
I know that Zygote does not allow mutating variables (here, the state), and that global variables are best avoided in Julia. Any recommendations?

According to the documentation, recurrence is straightforward in Flux. Basically, you define a (stateless) layer that takes two inputs h, x and produces two outputs h', y; that is, the state h is explicitly passed in and returned. Flux.Recur can then be used to pass the state implicitly, i.e., it turns your stateless layer into a stateful one.
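
To make the explicit state passing concrete, here is a minimal sketch in plain Julia with no Flux dependency. The names ToyCell and run_sequence are hypothetical and made up for illustration; they are not part of Flux. The point is only the shape of the pattern: the cell is a pure function of (h, x), and the caller threads h from one step to the next.

```julia
# Hypothetical stateless cell: all names here are illustrative, not Flux API.
struct ToyCell
    W_h::Matrix{Float64}
    W_x::Matrix{Float64}
    b::Vector{Float64}
end

# Takes the previous state h and the input x, returns (new state, output).
function (c::ToyCell)(h, x)
    h_new = tanh.(c.W_h * h .+ c.W_x * x .+ c.b)
    return h_new, h_new   # here the output is simply the new state
end

# Thread the state through a sequence by hand; the cell itself holds no state.
function run_sequence(cell, h0, xs)
    h = h0
    ys = Vector{Float64}[]
    for x in xs
        h, y = cell(h, x)
        push!(ys, y)
    end
    return h, ys
end

cell = ToyCell(fill(0.1, 3, 3), fill(0.2, 3, 2), zeros(3))
xs = [ones(2), ones(2), ones(2)]
h_final, ys = run_sequence(cell, zeros(3), xs)
```

Because the cell never modifies anything it was given, there is nothing for an automatic differentiation tool to trip over; a stateful wrapper only adds the convenience of not passing h around yourself.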


What do you think of this code?

# custom rnn layer
using Flux, Functors

struct RecModel
    W_h::Matrix{Float32}
    W_i::Matrix{Float32}
    b_h::Vector{Float32}
    b_i::Vector{Float32}
    state0::Vector{Float32}
end
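# Stateless cell: takes the previous state h and the input x, returns (new state, output).
(m::RecModel)(h, x) = (relu(m.W_h*h .+ m.b_h), relu(m.W_i*x .+ m.W_h*h .+ m.b_i))
# Expose only the weights and biases as parameters; state0 is left out.
@functor RecModel (W_h, W_i, b_h, b_i)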

m = RecModel(Flux.glorot_normal(10, 10), Flux.glorot_normal(10, 10), randn(Float32, 10), randn(Float32, 10), zeros(Float32, 10))
model = Flux.Recur(m, zeros(Float32, 10))
p = Flux.params(model)

loss(x, y) = Flux.mse(model(x), y)
loss(rand(Float32, 10), rand(Float32, 10))

g = Flux.gradient(()->loss(rand(Float32, 10), rand(Float32, 10)), p)

And here is the same thing without Flux.Recur:

mutable struct MyRecur{T, S}
    cell::T
    state::S
end
@functor MyRecur
Flux.trainable(m::MyRecur) = (m.cell,)
function (m::MyRecur)(x)
    m.state, y = m.cell(m.state, x)
    y
end
# Restore the wrapper's state to the cell's initial state.
my_reset!(m::MyRecur) = (m.state = m.cell.state0)

m = RecModel(Flux.glorot_normal(10, 10), Flux.glorot_normal(10, 10), randn(Float32, 10), randn(Float32, 10), zeros(Float32, 10))
model = MyRecur(m, zeros(Float32, 10))
p = Flux.params(model)

loss(x, y) = Flux.mse(model(x), y)
loss(rand(Float32, 10), rand(Float32, 10))
    
g = Flux.gradient(()->loss(rand(Float32, 10), rand(Float32, 10)), p)
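
One thing worth checking in the second version is the reset function, which is defined but never called. The sketch below uses plain Julia without Flux, and its names (ToyRecur, toy_cell, toy_reset!) are hypothetical and for illustration only. It shows why a reset between independent sequences matters: the wrapper's output depends on everything it has seen since the last reset.

```julia
# Hypothetical stateful wrapper in plain Julia; names are illustrative only.
mutable struct ToyRecur{F, S}
    cell::F      # stateless function (h, x) -> (h_new, y)
    state::S     # current hidden state
    state0::S    # initial state to restore on reset
end

function (m::ToyRecur)(x)
    m.state, y = m.cell(m.state, x)
    return y
end

# Restore the initial state; copy so the live state never aliases state0.
toy_reset!(m::ToyRecur) = (m.state = copy(m.state0); m)

# Toy cell: the new state is a running sum, and the output equals the state.
toy_cell(h, x) = (h .+ x, h .+ x)

m = ToyRecur(toy_cell, zeros(2), zeros(2))
y1 = m([1.0, 2.0])   # first call after construction
y2 = m([1.0, 2.0])   # same input, but the state carried over from the first call
toy_reset!(m)
y3 = m([1.0, 2.0])   # first call after the reset
```

Here y2 differs from y1 even though the input is identical, and y3 matches y1 again, so under these assumptions a training loop would call the reset before each new, unrelated sequence.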