Chain multiple Chains in Flux

Hello!
Is there any way to chain multiple Chains in Flux?
For example, say I want to refactor my model so that a group of layers that is repeated with different parameters is built by a helper function like this:

function conv_block(in_channels, out_channels)
   (Conv((3,3), in_channels => out_channels, relu, pad = (1,1), stride = (1,1)),
    BatchNorm(out_channels))
end

Instead of

Conv((3, 3), 3 => 64, relu, pad=(1, 1), stride=(1, 1)),
BatchNorm(64),
Conv((3, 3), 64 => 64, relu, pad=(1, 1), stride=(1, 1)),
BatchNorm(64),

so that the model would ultimately look something like this:

Chain(
    conv_block(3, 128),
    conv_block(128, 256),
)

You can splat the tuples/Chains with the ... operator, so that each layer is passed to Chain as a separate argument:

Chain(
    conv_block(3, 128)...,
    conv_block(128, 256)...,
)
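As a rough sketch of the difference between nesting and splatting: if conv_block returned a Chain rather than a tuple, the blocks could also be passed to Chain without splatting, since a Chain can hold other Chains as layers. The Chain-returning variant of conv_block and the 32×32 input size below are assumptions for illustration, not part of the original question, and splatting a Chain relies on Chain being iterable in your Flux version.

```julia
using Flux

# Assumed variant of the question's helper: returns a Chain instead of a tuple.
conv_block(in_channels, out_channels) = Chain(
    Conv((3, 3), in_channels => out_channels, relu; pad = (1, 1), stride = (1, 1)),
    BatchNorm(out_channels),
)

# Nested: the outer Chain holds two inner Chains as its layers.
nested = Chain(conv_block(3, 128), conv_block(128, 256))

# Flat: splatting passes each inner layer to the outer Chain individually.
flat = Chain(conv_block(3, 128)..., conv_block(128, 256)...)

# Assumed input: one 32x32 image with 3 channels (width, height, channels, batch).
x = rand(Float32, 32, 32, 3, 1)

@show length(nested) length(flat)
@show size(nested(x)) size(flat(x))
```

Both versions apply the same sequence of layers; the flat one just avoids an extra level of nesting when the model is printed or indexed.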

You can also define

chain(A::Tuple...) = Chain(mapreduce(collect, vcat, A)...)

to do

chain(conv_block(3, 128), conv_block(128, 3))

directly.
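To see what the mapreduce(collect, vcat, A) part of that helper does on its own, here is a small sketch that needs no Flux at all. The name flatten_blocks and the Symbols standing in for layers are hypothetical, chosen only for illustration.

```julia
# Same flattening step as in the chain helper, separated out so it can be inspected.
# Each tuple is collected into a vector, and the vectors are concatenated with vcat.
flatten_blocks(blocks::Tuple...) = mapreduce(collect, vcat, blocks)

# Symbols stand in for the (Conv, BatchNorm) pairs that conv_block returns.
block_a = (:conv1, :bn1)
block_b = (:conv2, :bn2)

layers = flatten_blocks(block_a, block_b)
println(layers)
```

The resulting vector is what gets splatted into Chain, so the final model sees one flat list of layers regardless of how many blocks were passed in.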


Oh!
Thank you so much!