Creating a custom container layer in Lux

When I use this definition in my implementation above, I get an error:

Yes, that is because SkipConnection dispatches to the fallback implementation in LuxCore. The definition you linked is for a very special case where the connection is itself a layer (as opposed to a basic operation).
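To make the "function vs. layer" distinction concrete, here is a minimal sketch of a residual container whose parameter and state initialization dispatches on the type of its connection. It is not Lux's actual SkipConnection source. MySkip, ScaleBy, _init_ps, _init_st and _connect are hypothetical names, and the code assumes LuxCore 1.x, where the abstract leaf type is AbstractLuxLayer (older releases call it AbstractExplicitLayer).

```julia
using LuxCore, Random

# Hypothetical leaf layer: elementwise scaling by a trainable vector
struct ScaleBy <: LuxCore.AbstractLuxLayer
    n::Int
end
LuxCore.initialparameters(::AbstractRNG, l::ScaleBy) = (weight = fill(2.0f0, l.n),)
LuxCore.initialstates(::AbstractRNG, ::ScaleBy) = NamedTuple()
(l::ScaleBy)(x, ps, st) = (x .* ps.weight, st)

# Hypothetical residual container; `connection` may be a plain function or a layer
struct MySkip{L <: LuxCore.AbstractLuxLayer, C} <: LuxCore.AbstractLuxLayer
    layer::L
    connection::C
end

# Fallback: a plain function such as `+` has no parameters or state
_init_ps(rng, c) = NamedTuple()
_init_st(rng, c) = NamedTuple()
# Special case: the connection is itself a layer, so it gets its own entries
_init_ps(rng, c::LuxCore.AbstractLuxLayer) = LuxCore.initialparameters(rng, c)
_init_st(rng, c::LuxCore.AbstractLuxLayer) = LuxCore.initialstates(rng, c)

LuxCore.initialparameters(rng::AbstractRNG, m::MySkip) =
    (layer = LuxCore.initialparameters(rng, m.layer), connection = _init_ps(rng, m.connection))
LuxCore.initialstates(rng::AbstractRNG, m::MySkip) =
    (layer = LuxCore.initialstates(rng, m.layer), connection = _init_st(rng, m.connection))

# Combine the layer output `y` with the input `x`
_connect(c, y, x, ps, st) = (c(y, x), st)                               # plain function
_connect(c::LuxCore.AbstractLuxLayer, y, x, ps, st) = c((y, x), ps, st)  # layer connection

function (m::MySkip)(x, ps, st)
    y, st_layer = m.layer(x, ps.layer, st.layer)
    z, st_conn = _connect(m.connection, y, x, ps.connection, st.connection)
    return z, (layer = st_layer, connection = st_conn)
end
```

With `MySkip(ScaleBy(3), +)` the connection entry is just an empty NamedTuple, which is the same kind of outcome you get when a plain operation falls through to a generic fallback.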

And is there a reason for using the type parameter interface to iterate over the layers?

Part of it stems from older versions of Julia (way back in 1.6), where type inference had trouble working through an entire chain. Today it exists as a safety measure so that type inference almost never fails. In user code, however, you should just iterate over l.layers instead of the type parameters.
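The two styles can be compared in plain Julia, without Lux. In this sketch MyChain, apply_loop and apply_unrolled are hypothetical names: the loop simply iterates over the layers field, while the @generated version reads the field names from the NamedTuple's type parameter and unrolls one call per layer, which is roughly the spirit of the type-parameter interface described above.

```julia
# Hypothetical container: a NamedTuple of plain callables (no Lux involved)
struct MyChain{L <: NamedTuple}
    layers::L
end

# Plain iteration: iterating a NamedTuple yields its values in order
function apply_loop(c::MyChain, x)
    for f in c.layers
        x = f(x)
    end
    return x
end

# Type-parameter style: names come from the type, and the calls are unrolled at compile time
@generated function apply_unrolled(c::MyChain{NamedTuple{names, T}}, x) where {names, T}
    calls = [:(x = getfield(c.layers, $(QuoteNode(n)))(x)) for n in names]
    return quote
        $(calls...)
        return x
    end
end
```

For example, with `c = MyChain((a = x -> 2x, b = x -> x + 1, c = sqrt))` both `apply_loop(c, 4.0)` and `apply_unrolled(c, 4.0)` return `3.0`. The loop is shorter and easier to maintain, which is why it is the better default in user code.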

What does a custom implementation need to return?

Your code will probably look something like this:

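using LuxCore, Random   # AbstractRNG comes from Random

function LuxCore.initialstates(rng::AbstractRNG, layer::MADELayer)
    # Default states of every sub-layer, a NamedTuple keyed like layer.layers
    st = LuxCore.initialstates(rng, layer.layers)
    _, MW_list = sample_masks_deep(....)
    # Iterate over the states (not the layer) and overwrite each mask in place;
    # the NamedTuple is immutable, but the mask arrays inside it are not
    for (I, st_) in enumerate(values(st))
        st_.mask[:] .= MW_list[I][:]
    end
    return st
end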
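A self-contained sketch of the same pattern, assuming LuxCore 1.x (AbstractLuxLayer). MaskedDense and toy_masks are hypothetical: toy_masks is only a stand-in for sample_masks_deep and uses a fixed lower-triangular rule, not MADE's degree-based mask sampling. The container builds its sub-layers' default states, then overwrites each mask array in place while iterating over values(st).

```julia
using LuxCore, Random

# Hypothetical leaf layer: a dense layer whose weight is multiplied by a mask kept in the state
struct MaskedDense <: LuxCore.AbstractLuxLayer
    in_dims::Int
    out_dims::Int
end
LuxCore.initialparameters(rng::AbstractRNG, l::MaskedDense) =
    (weight = randn(rng, Float32, l.out_dims, l.in_dims), bias = zeros(Float32, l.out_dims))
# Placeholder mask of ones; the container overwrites it
LuxCore.initialstates(::AbstractRNG, l::MaskedDense) = (mask = ones(Float32, l.out_dims, l.in_dims),)
(l::MaskedDense)(x, ps, st) = ((ps.weight .* st.mask) * x .+ ps.bias, st)

# Hypothetical container holding a NamedTuple of MaskedDense layers
struct MADELayer{L <: NamedTuple} <: LuxCore.AbstractLuxLayer
    layers::L
end

# Hypothetical stand-in for sample_masks_deep (toy rule, not real MADE masks)
toy_masks(m::MADELayer) =
    map(l -> Float32[j <= i for i in 1:l.out_dims, j in 1:l.in_dims], values(m.layers))

LuxCore.initialparameters(rng::AbstractRNG, m::MADELayer) =
    map(l -> LuxCore.initialparameters(rng, l), m.layers)

function LuxCore.initialstates(rng::AbstractRNG, m::MADELayer)
    st = map(l -> LuxCore.initialstates(rng, l), m.layers)
    masks = toy_masks(m)
    for (i, st_) in enumerate(values(st))
        st_.mask .= masks[i]   # mutate the array, not the NamedTuple
    end
    return st
end

# Forward pass: apply the sub-layers in order using their own ps/st entries
function (m::MADELayer)(x, ps, st)
    for k in keys(m.layers)
        x, _ = m.layers[k](x, ps[k], st[k])
    end
    return x, st
end
```

Usage: `model = MADELayer((l1 = MaskedDense(3, 3), l2 = MaskedDense(3, 3)))`, then `ps, st = LuxCore.setup(rng, model)` and `y, st = model(x, ps, st)`.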

Generally speaking, it needs to return a NamedTuple containing everything your forward-pass method (model::MADELayer)(x, ps, st) needs from the state.
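As a small illustration of that contract, here is a hypothetical layer (Shift, assuming LuxCore 1.x) whose forward pass reads a non-trainable offset and a call counter from its state and returns an updated state rather than mutating the old one.

```julia
using LuxCore, Random

# Hypothetical layer: adds a non-trainable offset and counts how often it was called
struct Shift <: LuxCore.AbstractLuxLayer
    offset::Float32
end

LuxCore.initialparameters(::AbstractRNG, ::Shift) = NamedTuple()   # no trainable parameters
LuxCore.initialstates(::AbstractRNG, l::Shift) = (offset = l.offset, ncalls = 0)

# Everything the forward pass reads comes from `st`; a new state is returned
function (l::Shift)(x, ps, st)
    y = x .+ st.offset
    return y, merge(st, (ncalls = st.ncalls + 1,))
end
```

Calling the layer twice with `Shift(0.5f0)` on `Float32[1, 2]`, threading the returned state through, gives `Float32[2, 3]` and a state with `ncalls == 2`.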
