Splitting and joining Flux model chains

I would like to create a Flux model that splits the input into two paths, processes each path differently, and then joins the results with a final model. Chain does not seem to allow parallel data paths. [Diagram: NetworkSplit]

I could make this work if I could stack an identity and model next to each other. For example:

Chain(x -> cat(x, x), Stack(Identity(8), Dense(8, 8, relu)), Stack(Identity(8), Dense(8, 8, relu)), Dense(16, 2))

I would think that this should be possible; I am just not experienced enough to see how.

If it really has to be a chain, I guess this does what you are after:

julia> d1 = Dense(8,8, relu);

julia> d2 = Dense(8,8,relu);

julia> d3 = Dense(16,2);

julia> chain = Chain(x -> (x, d1(x)), ((x, x1)::Tuple) -> (x, d2(x1)), ((x, x2)::Tuple) -> cat(x, x2, dims=1), d3);

julia> chain(ones(Float32, 8, 3))
2×3 Array{Float32,2}:
 -0.5452    -0.5452    -0.5452
 -0.761695  -0.761695  -0.761695

Flux also exports SkipConnection for cases like this:

julia> chain = Chain(SkipConnection(Chain(d1,d2), (x1, x) -> cat(x,x1,dims=1)), d3);

julia> chain(ones(Float32, 8, 3))
2×3 Array{Float32,2}:
 -0.5452    -0.5452    -0.5452
 -0.761695  -0.761695  -0.761695
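
For what it's worth, more recent Flux versions also export a Parallel layer, which is close to the "Stack" idea from the question: every branch receives the same input and the results are merged by a function you supply. The following is only a sketch, assuming a Flux version that provides Parallel and the `in => out` form of the Dense constructor. The layer sizes are the ones from the thread.

```julia
using Flux

# Same layer sizes as above; `=>` is the constructor form used by recent Flux versions.
d1 = Dense(8 => 8, relu)
d2 = Dense(8 => 8, relu)
d3 = Dense(16 => 2)

# Parallel applies each branch to the same input and merges the outputs with
# its first argument. Here the branches are `identity` (the untouched input)
# and Chain(d1, d2), merged with vcat into a 16-row array.
model = Chain(Parallel(vcat, identity, Chain(d1, d2)), d3)

x = ones(Float32, 8, 3)
y = model(x)
size(y)  # (2, 3)
```

With the same d1, d2 and d3, this should compute the same thing as the SkipConnection version. The only difference is that the order of the concatenated inputs is fixed by the order of the branches.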

Note that Flux does not require things to be a Chain. Any function will do:

julia> model = function(x)
       x1 = d1(x)
       x2 = d2(x1)
       x3 = cat(x, x2, dims=1)
       return d3(x3)
       end
#53 (generic function with 1 method)

julia> model(ones(Float32, 8, 3))
2×3 Array{Float32,2}:
 -0.5452    -0.5452    -0.5452
 -0.761695  -0.761695  -0.761695
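
If you would rather have the layers travel with the model instead of living in global variables, the same computation can be wrapped in a small callable struct. This is just a sketch: the name SplitJoin and its field names are made up for illustration and are not part of Flux, and the `in => out` Dense constructor assumes a recent Flux version.

```julia
using Flux

# Hypothetical wrapper type; SplitJoin is not something Flux provides.
struct SplitJoin{B,H}
    branch::B  # path that transforms the input
    head::H    # layer applied after the two paths are joined
end

# Forward pass: keep x untouched on one path, transform it on the other,
# concatenate along the feature dimension, then apply the head.
(m::SplitJoin)(x) = m.head(vcat(x, m.branch(x)))

model = SplitJoin(Chain(Dense(8 => 8, relu), Dense(8 => 8, relu)), Dense(16 => 2))

x = ones(Float32, 8, 3)
y = model(x)
size(y)  # (2, 3)
```

To train such a struct, Flux has to be able to find the parameters in its fields. Depending on the Flux version this is done by registering the type with Flux.@functor or Flux.@layer, so check the documentation for the version you have installed.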

If you want to build and modify models programmatically, you can try https://github.com/DrChainsaw/NaiveNASflux.jl, although it might be overkill for what you are trying to achieve.
