Flux: concatenate layers

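I’m looking for a way to concatenate layers in Flux: two parallel branches (A1 → B1 and A2 → B2 → C2) whose outputs are concatenated and then passed through D and E.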

Would the following work? The input is an array of inputs, one per branch, and the output is the concatenation of the branches’ outputs. Creating the model would look something like Chain(Concat([Chain(A1, B1), Chain(A2, B2, C2)]), D, E).

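using Flux
using Flux: @treelike   # newer Flux versions use Flux.@functor / Flux.@layer instead

struct Concat
    layers::Array   # one layer (or Chain) per branch
end

@treelike Concat   # lets Flux collect the parameters of the branch layers

# Apply the i-th layer to the i-th input and concatenate the outputs along
# dimension 1 (features). Unlike append! into [], vcat keeps a batched
# (features × batch) output as a matrix instead of flattening it.
function (c::Concat)(inputs::Array)
    outputs = [layer(x) for (layer, x) in zip(c.layers, inputs)]
    vcat(outputs...)
end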
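A minimal sketch of what the Concat forward pass computes, using plain Julia functions as hypothetical stand-ins for the branch layers, so it runs without Flux. ConcatSketch, branch1 and branch2 are illustrative names, not part of Flux. Each branch gets its own input, and the outputs are stacked with vcat along dimension 1 (features), which keeps the batch dimension intact.

```julia
# Hypothetical stand-in for the Concat layer: branches are stored as a tuple,
# and any callable (a Flux layer or a plain function) can be a branch.
struct ConcatSketch{T<:Tuple}
    layers::T
end
ConcatSketch(layers...) = ConcatSketch(layers)

# Apply the i-th branch to the i-th input, then stack the outputs along
# dimension 1 (features), which preserves any batch dimension.
function (c::ConcatSketch)(inputs)
    outputs = [layer(x) for (layer, x) in zip(c.layers, inputs)]
    return reduce(vcat, outputs)
end

# Two hypothetical branches acting on (features × batch) matrices.
branch1 = x -> 2 .* x   # 3 features in, 3 features out
branch2 = x -> x .+ 1   # 2 features in, 2 features out

c = ConcatSketch(branch1, branch2)
y = c([ones(3, 4), zeros(2, 4)])   # batch of 4 samples
size(y)                            # (5, 4): 3 + 2 features, batch kept

# For comparison, append! into [] flattens the batch into one long vector:
length(append!([], ones(3, 4)))    # 12
```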

You can simply do

Chain(x -> cat(Chain(A1, B1)(x), Chain(A2, B2, C2)(x), dims=3), # Concatenating along channel dimension
          D, E)
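The dims=3 choice fits convolutional branches: Flux lays out image batches as width × height × channels × batch (WHCN), so dimension 3 is the channel axis. A shape-only sketch, with random arrays standing in for the outputs of two hypothetical conv branches:

```julia
# Random arrays standing in for the outputs of two conv branches, in Flux's
# WHCN layout: width × height × channels × batch.
branch1_out = rand(Float32, 8, 8, 3, 2)   # 3 channels, batch of 2
branch2_out = rand(Float32, 8, 8, 5, 2)   # 5 channels, batch of 2

# Concatenating along dims=3 stacks the channels and keeps everything else.
merged = cat(branch1_out, branch2_out; dims=3)
size(merged)   # (8, 8, 8, 2): 3 + 5 = 8 channels
```

Under this assumption, the next conv layer would need 8 input channels.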

Thanks!

Can you help explain why you are concatenating along the channel dimension (dims=3)?

The channel dimension is usually synonymous with “features”: we assume each network learns features, and combining several networks for further processing means combining their features in another network.

If this is the case, why do I get an error when I run the following code, which attempts to reproduce the suggestion above?

using Flux
A1 = Dense(5, 5)
B1 = Dense(5, 5)
A2 = Dense(5, 5)
B2 = Dense(5, 5)
C2 = Dense(5, 5)
D = Dense(5, 5)
E = Dense(5, 5)
model = Chain(x -> cat(Chain(A1, B1)(x), Chain(A2, B2, C2)(x); dims=3), D, E)

model(rand(5))

ERROR: MethodError: no method matching *(::TrackedArray{…,Array{Float32,2}}, ::TrackedArray{…,Array{Float32,3}})

The error appears for any input; concatenating along the third dimension does not seem to work.
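The shapes explain the error. Each Dense branch maps a 5-element vector to a 5-element vector, and cat(...; dims=3) on two vectors treats each as a 5×1 matrix and stacks them into a 5×1×2 array. D then multiplies its 5×5 weight matrix by that 3-D array, and there is no * method for a matrix times a 3-D array. A sketch with plain arrays standing in for the branch outputs and for D's weight:

```julia
# Plain arrays standing in for the two branch outputs (5 features each)
# and for D's 5×5 weight matrix.
out1, out2 = rand(5), rand(5)
W = rand(5, 5)

stacked = cat(out1, out2; dims=3)
size(stacked)   # (5, 1, 2): a 3-D array, not a longer vector

# Matrix × 3-D array has no method, which is the MethodError reported above.
try
    W * stacked
catch err
    println(typeof(err))   # MethodError
end
```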


That’s because the previous respondents for some reason assumed that the layers were convolutional and the input two-dimensional (images). Dense, on the other hand, takes in and returns a vector (or a batch of vectors, stored as a matrix), so the branch outputs have to be concatenated along dimension 1. This will work:

using Flux
A1 = Dense(5, 5)
B1 = Dense(5, 5)
A2 = Dense(5, 5)
B2 = Dense(5, 5)
C2 = Dense(5, 5)
D = Dense(10, 5)
E = Dense(5, 5)
model = Chain(x -> cat(Chain(A1, B1)(x), Chain(A2, B2, C2)(x); dims=1), D, E)

model(rand(5))

or, equivalently,

model = Chain(x -> vcat(Chain(A1, B1)(x), Chain(A2, B2, C2)(x)), D, E)

(Note that we have to write D = Dense(10, 5) instead of D = Dense(5, 5), because D now receives the two 5-element branch outputs stacked into 10 features.)
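Both forms also work on batches. Dense(10, 5) holds a 5×10 weight matrix, so it expects 10 input features, and concatenating two 5-feature outputs along dimension 1 gives exactly that for a single vector or for a features × batch matrix. A shape-only sketch, with plain arrays standing in for the branch outputs and for D's weight:

```julia
# Plain arrays standing in for the outputs of the two branches.
single1, single2 = rand(5), rand(5)          # one sample
batch1, batch2 = rand(5, 16), rand(5, 16)    # batch of 16 samples (columns)

# cat(...; dims=1) and vcat give the same result for vectors and matrices.
@assert cat(single1, single2; dims=1) == vcat(single1, single2)
@assert cat(batch1, batch2; dims=1) == vcat(batch1, batch2)

size(vcat(single1, single2))   # (10,)
size(vcat(batch1, batch2))     # (10, 16): features stacked, batch kept

# Stand-in for D = Dense(10, 5): a 5×10 weight maps 10 features to 5.
W = rand(5, 10)
size(W * vcat(batch1, batch2))  # (5, 16)
```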


As @tomerarnon suggested in the answer to this question, the solution proposed here doesn’t work properly. A workaround is to use the Concat struct.
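In newer Flux versions (assumed here: Flux 0.13 or later), the built-in Parallel layer plays the role of the Concat struct. Parallel(vcat, branch1, branch2) applies every branch to the same input or, when given a tuple of inputs, applies the i-th branch to the i-th input, and joins the outputs with vcat. Because the branches are stored as fields of a layer rather than captured in an anonymous function, Flux's parameter traversal can reach them. Layers captured only inside a closure may be invisible to that traversal depending on the Flux version, which is one plausible reason the closure approach misbehaves in training. A sketch reusing the layer sizes from this thread:

```julia
using Flux   # assumes Flux 0.13 or later (Parallel, Dense(in => out) syntax)

A1, B1 = Dense(5 => 5), Dense(5 => 5)
A2, B2, C2 = Dense(5 => 5), Dense(5 => 5), Dense(5 => 5)
D, E = Dense(10 => 5), Dense(5 => 5)   # D takes the 5 + 5 concatenated features

# Built-in replacement for the Concat struct: outputs are joined with vcat.
concat = Parallel(vcat, Chain(A1, B1), Chain(A2, B2, C2))
model = Chain(concat, D, E)

x = rand(Float32, 5, 16)        # batch of 16 samples
size(model(x))                  # (5, 16): same x fed to both branches

# A tuple of inputs sends the i-th input to the i-th branch, like Concat.
x1, x2 = rand(Float32, 5, 16), rand(Float32, 5, 16)
size(model((x1, x2)))           # (5, 16)
```

Older Flux versions, such as the Tracker-based ones this thread was written against, do not have Parallel, so there the Concat struct remains the workaround.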