Creating a custom container layer in Lux

Hi,
I’m trying to implement a custom container layer in Lux, following the manual, to chain custom masked linear layers together. Each masked layer has a user-defined mask, and all masks should be set for the entire model when calling initialstates(rng, container_layer).

I’m having difficulties understanding how (and why) to work with a named tuple as the type parameter for the container layer.

The definition of the individual layer is

struct DenseMasked <: LuxCore.AbstractExplicitLayer
    activation
    in_dims::Int
    out_dims::Int
    init_weight
    init_bias
end

LuxCore.initialstates(::AbstractRNG, d::DenseMasked) = (mask=BitArray(true for i=1:d.out_dims, j=1:d.in_dims), )


@inline function (d::DenseMasked)(x::AbstractMatrix, ps, st::NamedTuple)
    y = match_eltype(d, ps, st, x)
    return (
        fused_dense_bias_activation(
            d.activation, st.mask .* ps.weight, y, Lux._vec(Lux._getproperty(ps, Val(:bias)))),
        st)
end

I also have a function to calculate the mask matrix (which I set in states)

function sample_masks_deep(input_dim, hidden_dims; ordering=:natural)

This setup works when I instantiate like this:

num_input = 2
num_hidden = 8
model = Chain(DenseMasked(num_input, num_hidden, relu), DenseMasked(num_hidden, 2*num_input))
ps, st = Lux.setup(rng, model)
_, MW_list = sample_masks_deep(num_input, [num_hidden])

st.layer_1.mask[:] .= MW_list[1][:];
st.layer_2.mask[:] .= MW_list[2][:];

Now I want to define a custom container layer where this setup is handled in initialstates, similar to SkipConnection

@concrete struct MADELayer <: LuxCore.AbstractExplicitContainerLayer{(:layers, )}
    layers <: NamedTuple
    name
end

function MADELayer(xs...; name::Lux.NAME_TYPE=nothing)
    return MADELayer(Lux.__named_tuple_layers(xs...), name)
end


# Calls applychain as defined for Chain
(m::MADELayer)(x, ps, st) = Lux.applychain(m.layers, x, ps, st)

# This is where masks will be initialized
function LuxCore.initialstates(rng::AbstractRNG, l::MADELayer{layers}) where {layers}
    @info "initialstates"
end

The implementation for SkipConnection

function initialstates(rng::AbstractRNG, l::SkipConnection{T, <:AbstractExplicitLayer}) where {T}
    return NamedTuple{layers}(initialstates.(rng, getfield.((l,), layers)))
end

uses l::SkipConnection{T, <:AbstractExplicitLayer}, while the struct itself is defined as @concrete struct SkipConnection <: AbstractExplicitContainerLayer{(:layers,)}.

When I use this definition in my implementation above, I get an error:


ERROR: MethodError: no method matching initialstates(::Xoshiro, ::MADELayer{@NamedTuple{layer_1::DenseMasked, layer_2::DenseMasked}, String})

Closest candidates are:
  initialstates(::AbstractRNG, ::MADELayer{T, <:LuxCore.AbstractExplicitContainerLayer}) where T
   @ Main ~/julia_envs/lux_playground/src/deep_maf.jl:37

Why does the type annotation on the function argument differ from the struct definition here?
And is there a reason for using the type parameter interface to iterate over the layers? What does a custom implementation need to return?

When I use this definition in my implementation above, I get an error:

Yes, that is because SkipConnection dispatches to the fallback implementation in LuxCore. The definition that you link is for a very special case where the connection is itself a layer (in contrast to a basic operation).
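
The MethodError can be reproduced without Lux. Here is a minimal sketch, assuming hypothetical stand-in types (AbstractLayer, Dense, Container and describe are illustration names, not Lux API), with two type parameters like the ones visible in the error message above. A method that constrains the second parameter to a layer type does not apply when that parameter is a String:

```julia
# Hypothetical stand-ins; none of these names come from Lux.
abstract type AbstractLayer end
struct Dense <: AbstractLayer end

# One type parameter per field, mirroring MADELayer{<layers type>, String}.
struct Container{L, N}
    layers::L
    name::N
end

# Only matches when the second parameter is a subtype of AbstractLayer.
describe(c::Container{T, <:AbstractLayer}) where {T} = "special case"

c = Container((layer_1 = Dense(), layer_2 = Dense()), "made")

# The second parameter is String here, so the constrained method does not apply.
no_match = !applicable(describe, c)

# A method without parameter constraints accepts any Container.
describe(c::Container) = "generic fallback"
result = describe(c)
```

Annotating the argument as plain MADELayer, with no constraints on the parameters, sidesteps the mismatch.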

And is there a reason for using the type parameter interface to iterate over the layers?

Part of it stems from older versions of Julia (way back in 1.6), where there were issues getting type inference to work through an entire chain. Currently it exists for safety, to ensure type inference almost never fails. However, in user code you should just iterate over l.layers instead of the type parameters.
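
As a rough illustration of iterating over the layers directly, here is a sketch in plain Julia. Toy and toy_state are hypothetical stand-ins for a layer type and its state initializer, not Lux API. Calling map on a NamedTuple keeps the field names, so the result has one entry per layer:

```julia
# Hypothetical stand-in for a layer with a square mask of size n x n.
struct Toy
    n::Int
end

# Hypothetical per-layer state initializer returning a NamedTuple.
toy_state(t::Toy) = (mask = trues(t.n, t.n),)

layers = (layer_1 = Toy(2), layer_2 = Toy(3))

# map over a NamedTuple preserves the keys layer_1 and layer_2.
states = map(toy_state, layers)
```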

What does a custom implementation need to return?

Your code will probably look something like:

function LuxCore.initialstates(rng::AbstractRNG, layer::MADELayer)
    st = LuxCore.initialstates(rng, layer.layers)
    _, MW_list = sample_masks_deep(....)
    for (I, st_) in enumerate(values(st))
        st_.mask[:] .= MW_list[I][:]
    end
    return st
end

Generally speaking, it needs to return a NamedTuple containing everything that your call method (model::MADELayer)(x, ps, st) needs.
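
To make the shape of that return value concrete, here is a sketch in plain Julia without Lux. The layer sizes follow the 2 -> 8 -> 4 example from the question, and fake_masks is a hypothetical stand-in for the output of sample_masks_deep, whose body is not shown in this thread:

```julia
# Per-layer states as a NamedTuple of NamedTuples, one entry per layer.
st = (layer_1 = (mask = trues(8, 2),), layer_2 = (mask = trues(4, 8),))

# Hypothetical masks standing in for the result of sample_masks_deep.
fake_masks = [falses(8, 2), falses(4, 8)]
fake_masks[1][1, 1] = true

# NamedTuples are immutable, but the BitMatrix each one holds is not,
# so the masks can be overwritten in place.
for (i, st_) in enumerate(values(st))
    st_.mask .= fake_masks[i]
end
```

Because the masks are mutated in place, the same st can be returned afterwards without rebuilding the NamedTuple.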


Thanks for these insights, they are quite helpful when extending Lux with custom layers.