Elegant way to handle multiple-input Flux layers?

The issue comes up when implementing an encoder for transformer models with masking. Below is a minimal, made-up example instead of the full code.

We have a custom layer which takes two inputs. For a transformer these could be input_feature and mask:

using Flux
using Flux: @functor

struct MyCustomLayer
    l::Dense
end

@functor MyCustomLayer

(m::MyCustomLayer)(x, y) = m.l(x) .+ y

MyCustomLayer(in_dim, out_dim) = MyCustomLayer(Dense(in_dim, out_dim))
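For reference, calling the layer on its own looks like this (shapes are arbitrary, just for illustration):

layer = MyCustomLayer(8, 16)
x = rand(Float32, 8, 4)    # features: in_dim × batch
y = zeros(Float32, 16, 4)  # second input (e.g. a mask), broadcastable with the Dense output
layer(x, y)                # 16 × 4 output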

Now we can define a simple block something like this

Block(in_dim, hidden_dim, out_dim) = Chain(Dense(in_dim, hidden_dim), MyCustomLayer(hidden_dim, out_dim))

Now this will obviously break, because MyCustomLayer takes an additional input. It can of course be fixed by declaring Block as a struct as well, with a forward pass that takes two inputs (see the sketch below).
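For concreteness, this is roughly what that struct-based fix looks like (BlockStruct is a hypothetical name, and the sketch builds on the definitions above):

struct BlockStruct
    dense::Dense
    custom::MyCustomLayer
end

@functor BlockStruct

BlockStruct(in_dim, hidden_dim, out_dim) =
    BlockStruct(Dense(in_dim, hidden_dim), MyCustomLayer(hidden_dim, out_dim))

# The block itself now also needs a two-argument forward pass,
# and so does anything that contains it.
(b::BlockStruct)(x, y) = b.custom(b.dense(x), y)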

However, I find that inelegant and not very Flux-like, because if I now want to use Block as a component in any higher-level layer, say an Encoder, I have no choice but to define that as a struct too! This leads to the whole model being structs, which I would like to avoid since we then can’t make use of the nice container layers that make the code so readable.

Any suggestions are welcome. I have also considered defining a custom struct for packing the inputs:

struct PackedInputs
    features
    mask
end

and then defining custom forward passes for a bunch of layers like Dense and LayerNorm, such that most layers act only on the features, while the layers that need both the features and the mask use both (a sketch of this is below). I am not sure that is the best we can do.
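A minimal sketch of that packing idea, assuming the PackedInputs struct above and the toy MyCustomLayer as the mask-aware layer (the added methods are just illustrative):

# Mask-oblivious layers act on the features and pass the mask through unchanged.
(d::Dense)(p::PackedInputs) = PackedInputs(d(p.features), p.mask)
(ln::LayerNorm)(p::PackedInputs) = PackedInputs(ln(p.features), p.mask)

# Mask-aware layers pull out both parts.
(m::MyCustomLayer)(p::PackedInputs) = PackedInputs(m(p.features, p.mask), p.mask)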

This is one case where Parallel shines:

Block(in_dim, hidden_dim, out_dim) = Parallel(
  MyCustomLayer(hidden_dim, out_dim); # "connection" joining the two branch outputs
  xbranch = Dense(in_dim, hidden_dim), # receives the first input (the features)
  ybranch = identity # receives the second input (the mask) and passes it through
)

The idea is that this block can be expressed as a more generalized skip connection: Parallel routes each positional input to its own branch and then combines the branch outputs with the connection. To use it, just call block(input_features, mask).
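A quick sketch of what that call looks like (dimensions are arbitrary; the mask just has to broadcast with the block's out_dim × batch output):

block = Block(8, 16, 16)
x = rand(Float32, 8, 4)       # features: in_dim × batch
mask = zeros(Float32, 16, 4)  # broadcastable with the 16 × 4 output
block(x, mask)                # xbranch gets x, ybranch gets mask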

I am not sure this addresses my concerns fully. The block itself will be just one of the layers; the layers before and after it might consume only a single input. In a complex network this seems to mean a lot of Parallel scattered throughout the code, since the mask is accessed across several layers.

To give you an idea, here is the actual sample code. Please ignore the mask handling: in the original implementation the mask was a field of the MHA, which I had to change because the mask might differ for every batch.

function Block(;n_embed, n_head, block_size, attention_dropout, residual_dropout, mask)
    mha = MHSelfAttention(n_embed=n_embed, n_head=n_head, block_size=block_size, attention_dropout=attention_dropout,
                          residual_dropout=residual_dropout, mask=mask)
    MLP = Chain(Dense(n_embed, 4 * n_embed, gelu), Dense(4 * n_embed, n_embed), Dropout(residual_dropout))

    return Chain(SkipConnection(Chain(LayerNorm(n_embed), mha), +),
                 SkipConnection(Chain(LayerNorm(n_embed), MLP), +))
end

function PositionalAwareEmbedding(vocab_size, embed_size, block_size)
    embed = Flux.Embedding(vocab_size, embed_size)
    pos_embed = Chain(x -> 1:block_size, Flux.Embedding(block_size, embed_size))
    return Parallel(.+, embed, pos_embed) # change to remove dot
end

function Decoder(;vocab_size, n_layers, n_embed, n_head, block_size, attention_dropout, residual_dropout, decoder_mask=true)
    mask = decoder_mask ? make_decoder_mask(block_size) : zeros(block_size, block_size)
    # this mask is shared across all layers so we can override it before the forward pass
    embed = PositionalAwareEmbedding(vocab_size, n_embed, block_size)
    blocks = [Block(n_embed=n_embed, n_head=n_head, block_size=block_size, attention_dropout=attention_dropout,
                    residual_dropout=residual_dropout, mask=mask) for _ in 1:n_layers]
    lm_head = Dense(embed[1].weight') # n_embed, vocab_size (weight tying)
    return Chain(embed, blocks..., LayerNorm(n_embed), lm_head)
end

I am not sure I am coming across clearly on why sprinkling Parallel throughout the code seems like a bad idea for readability. It is because every layer other than the MHA has to be told to ignore the mask and process only the features.

I don’t see how it would be too bad. You can always put Parallel at a higher level, and I’d argue that having Parallel throughout the code is good for readability because it clearly signals intent.
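To illustrate the "Parallel at a higher level" point, a sketch (names and dimensions are made up): all the mask-oblivious layers can sit in one feature-only branch, so you only need one Parallel per mask-aware layer rather than one per layer.

feature_path = Chain(Dense(8, 16), LayerNorm(16))  # all the mask-oblivious layers in one branch
block = Parallel(
    MyCustomLayer(16, 16);   # mask-aware "connection" from the earlier example
    features = feature_path, # receives the first input
    mask = identity          # receives the second input, untouched
)
block(rand(Float32, 8, 4), zeros(Float32, 16, 4))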

That said, there are other ways to approach this. You could make an IgnoreMask wrapper layer which lets the wrapped layer(s) see just the features and re-wraps the result into (features, mask) afterwards (a sketch is below). You could define a masked array type which contains features + mask and dispatch on it in particular layers to pull out either or both parts as needed. You could copy (or depend on) what Transformers.jl does to handle more complex network topologies. There are plenty of options, so it comes down to which makes the most sense for your use case.
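For what it’s worth, a minimal sketch of the IgnoreMask idea, assuming the inputs travel through the Chain as a (features, mask) tuple (all names here are illustrative):

using Flux
using Flux: @functor

struct IgnoreMask{L}
    layer::L
end

@functor IgnoreMask

# Run only the features through the wrapped layer; pass the mask along untouched.
(w::IgnoreMask)((x, mask)::Tuple) = (w.layer(x), mask)

# A mask-aware layer can accept the packed tuple directly, e.g. for MyCustomLayer above:
(m::MyCustomLayer)((x, mask)::Tuple) = m.l(x) .+ mask

model = Chain(IgnoreMask(Dense(8, 16)), IgnoreMask(LayerNorm(16)), MyCustomLayer(16, 16))
model((rand(Float32, 8, 4), zeros(Float32, 16, 4)))  # returns a plain array from the last layer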


Yep, this is the route I think I would take. I just wanted to make sure there wasn’t some less obvious approach I was missing entirely. Thanks a ton for the ideas!