The issue I am pointing at arrives when implementing encoder for transformer models with masking. Below I have tried to do a fake minimal example instead.
We have a custom layer which takes a couple of inputs. For a transformer this can be input_feature, mask
using Flux
using Flux: @functor
struct MyCustomLayer
l::Dense
end
@functor MyCustomLayer
(m::MyCustomLayer)(x, y) = m.l(x) .+ y
MyCustomLayer(in_dim, out_dim) = MyCustomLayer(Dense(in_dim, out_dim))
Now we can define a simple block something like this
Block(in_dim, hidden_dim, out_dim) = Chain(Dense(in_dim, hidden_dim), MyCustomLayer(hidden_dim, out_dim))
Now this will obviously break because MyCustomLayer takes an additional input. This can be fixed ofcourse by declaring Block as a struct and a custom layer which takes two inputs as well.
However, I find that inelegant and flux like because now if I want to use Block as a component in any higher layer say an Encoder I have no choice but to define that as struct too! This leads to the whole model being structs which I would like to avoid since we can’t make use of the nice container layers which makes code super readable.
Any suggestions are welcome. I have also considered defining a custom struct for packing the inputs:
struct PackedInputs
features
mask
end
and then defining custom forward passes for bunch of layers like Dense, LayerNorm such that most layers only act on the features and the layers that require both the mask and feature use them. Not sure if that is the best we can do.