I’m trying to implement a custom container layer in Lux, following the manual, that groups several custom masked linear layers. The idea is that each masked layer has a user-defined mask, and the masks for the entire model should all be set when calling initialstates(rng, container_layer).
I’m having difficulties understanding how (and why) to work with a named tuple as the type parameter for the container layer.
The definition of the individual layer is
struct DenseMasked <: LuxCore.AbstractExplicitLayer
LuxCore.initialstates(::AbstractRNG, d::DenseMasked) = (mask=BitArray(true for i=1:d.out_dims, j=1:d.in_dims), )
@inline function (d::DenseMasked)(x::AbstractMatrix, ps, st::NamedTuple)
y = match_eltype(d, ps, st, x)
return (
d.activation, st.mask .* ps.weight, y, Lux._vec(Lux._getproperty(ps, Val(:bias)))),
I also have a function to calculate the mask matrices (which I then write into the states):
function sample_masks_deep(input_dim, hidden_dims; ordering=:natural)
This setup works when I instantiate like this:
num_input = 2
num_hidden = 8
model = Chain(DenseMasked(num_input, num_hidden, relu), DenseMasked(num_hidden, 2*num_input))
ps, st = Lux.setup(rng, model)
_, MW_list = sample_masks_deep(num_input, [num_hidden])
st.layer_1.mask[:] .= MW_list[1][:];
st.layer_2.mask[:] .= MW_list[2][:];
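The per-layer assignments above could also be written as a loop, which is roughly what I would like the container to do for me. This is only a sketch in plain Julia: `st` is a hand-built stand-in for the state returned by Lux.setup, and `MW_list` holds random placeholder masks instead of the output of sample_masks_deep.

```julia
# Stand-in for the state returned by Lux.setup: one `mask` entry per layer.
st = (layer_1 = (mask = trues(8, 2),), layer_2 = (mask = trues(4, 8),))

# Placeholder masks (assumption); the real ones come from sample_masks_deep.
MW_list = [rand(Bool, 8, 2), rand(Bool, 4, 8)]

# Copy each mask into the matching layer state, in layer order.
for (layer_st, MW) in zip(values(st), MW_list)
    size(layer_st.mask) == size(MW) || throw(DimensionMismatch("mask shape mismatch"))
    layer_st.mask .= MW   # in-place update; the NamedTuple itself is not rebuilt
end
```

The size check is there because a broadcast assignment would otherwise silently accept some mismatched shapes, such as a single-column mask.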
Now I want to define a custom container layer where this setup is handled in initialstates, similar to SkipConnection:
@concrete struct MADELayer <: LuxCore.AbstractExplicitContainerLayer{(:layers, )}
layers <: NamedTuple
function MADELayer(xs...; name::Lux.NAME_TYPE=nothing)
return MADELayer(Lux.__named_tuple_layers(xs...), name)
# Calls applychain as defined for Chain
(m::MADELayer)(x, ps, st) = Lux.applychain(m.layers, x, ps, st)
# This is where masks will be initialized
function LuxCore.initialstates(rng::AbstractRNG, l::MADELayer{layers}) where {layers}
@info "initialstates"
The implementation for SkipConnection
function initialstates(rng::AbstractRNG, l::SkipConnection{T, <:AbstractExplicitLayer}) where {T}
return NamedTuple{layers}(initialstates.(rng, getfield.((l,), layers)))
uses l::SkipConnection{T, <:AbstractExplicitLayer}, while the struct itself is defined as @concrete struct SkipConnection <: AbstractExplicitContainerLayer{(:layers,)}
When I use this definition in my implementation above, I get an error:
ERROR: MethodError: no method matching initialstates(::Xoshiro, ::MADELayer{@NamedTuple{layer_1::DenseMasked, layer_2::DenseMasked}, String})
Closest candidates are:
initialstates(::AbstractRNG, ::MADELayer{T, <:LuxCore.AbstractExplicitContainerLayer}) where T
@ Main ~/julia_envs/lux_playground/src/deep_maf.jl:37
Why does the type annotation on the function argument differ from the struct definition here?
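To narrow this down, here is a minimal sketch in plain Julia, without Lux, that seems to reproduce the mismatch. It assumes that @concrete turns every untyped field into a type parameter, in field order (my reading of ConcreteStructs.jl, which may be incomplete). AbstractLayer, Leaf, SkipLike, MadeLike, describe and describe_any are hypothetical stand-ins, not Lux names.

```julia
# Stand-in for LuxCore.AbstractExplicitLayer (hypothetical name).
abstract type AbstractLayer end

struct Leaf <: AbstractLayer end

# Hand-written versions of what @concrete is assumed to generate:
# one type parameter per field, in field order.
struct SkipLike{L, C} <: AbstractLayer
    layers::L
    connection::C
end

struct MadeLike{L <: NamedTuple, N} <: AbstractLayer
    layers::L
    name::N
end

# The second parameter is the type of `connection`, so this method only
# applies when the connection is itself a layer.
describe(::SkipLike{T, <:AbstractLayer}) where {T} = "connection is a layer"
describe(::SkipLike) = "connection is not a layer"

# Copying the same pattern constrains the type of `name`, not of `layers`.
describe(::MadeLike{T, <:AbstractLayer}) where {T} = "name is a layer"

made = MadeLike((layer_1 = Leaf(), layer_2 = Leaf()), "made")

# `name` is a String here, so no method matches, as in the MethodError above.
no_match = !applicable(describe, made)

# Leaving the parameters unconstrained matches every MadeLike.
describe_any(::MadeLike) = "matches any MadeLike"
```

If that assumption holds, the braces after the type name in a method signature refer to the struct's own field-type parameters, which would explain why the second slot in the error message shows String.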
And is there a reason for using the type parameter interface to iterate over the layers? What does a custom implementation need to return?
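My current guess at these last two points, again as a plain-Julia sketch rather than actual Lux code: the (:layers,) in the supertype is a tuple of field names, which is a different thing from the struct's own first type parameter, and a custom implementation probably needs to return a NamedTuple with the same layout that Lux.setup produced for the Chain above (st.layer_1.mask, st.layer_2.mask). AbstractContainer, MaskedLeaf, MadeContainer, init_state, container_fields and layers_type are hypothetical names, and the random masks are placeholders.

```julia
using Random

# Stand-in for LuxCore.AbstractExplicitContainerLayer (hypothetical name).
# The parameter is a tuple of field-name Symbols, not a type.
abstract type AbstractContainer{fields} end

struct MaskedLeaf
    in_dims::Int
    out_dims::Int
end

struct MadeContainer{L <: NamedTuple} <: AbstractContainer{(:layers,)}
    layers::L
end

# Default state of a single layer: an all-true mask.
init_state(::AbstractRNG, d::MaskedLeaf) = (mask = trues(d.out_dims, d.in_dims),)

# Here `fields` binds to (:layers,) from the abstract supertype.
container_fields(::AbstractContainer{fields}) where {fields} = fields

# Here `L` binds to MadeContainer's own first parameter, the NamedTuple type.
layers_type(::MadeContainer{L}) where {L} = L

# Custom state initialization: one entry per sub-layer, keyed like `l.layers`.
function init_state(rng::AbstractRNG, l::MadeContainer)
    return map(l.layers) do d
        st = init_state(rng, d)
        st.mask .= rand(rng, Bool, size(st.mask)...)  # placeholder for real masks
        st
    end
end

made = MadeContainer((layer_1 = MaskedLeaf(2, 8), layer_2 = MaskedLeaf(8, 4)))
st = init_state(Random.default_rng(), made)
```

In this sketch, writing `l::MadeContainer{layers}` would bind `layers` to the NamedTuple type rather than to (:layers,), which may be what goes wrong in my initialstates method above.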