Splitting and joining Flux model chains

I would like to create a Flux model that splits the input in two, processes it in two different ways, and then joins it back together using a model. Chain does not seem to allow parallel data paths. The diagram is here: [image: NetworkSplit]

I could make this work if I could stack an identity and model next to each other. For example:

Chain(x -> cat(x, x), Stack(Identity(8), Dense(8, 8, relu)), Stack(Identity(8), Dense(8, 8, relu)), Dense(16, 2))

I would think that this should be possible; I am just not experienced enough.


If it really has to be a chain, I guess this does what you are after:

julia> d1 = Dense(8,8, relu);

julia> d2 = Dense(8,8,relu);

julia> d3 = Dense(16,2);

julia> chain = Chain(x -> (x, d1(x)), ((x, x1)::Tuple) -> (x, d2(x1)), ((x, x2)::Tuple) -> cat(x, x2, dims=1), d3);

julia> chain(ones(Float32, 8, 3))
2×3 Array{Float32,2}:
 -0.5452    -0.5452    -0.5452
 -0.761695  -0.761695  -0.761695

Flux also exports SkipConnection for cases like this:

julia> chain = Chain(SkipConnection(Chain(d1,d2), (x1, x) -> cat(x,x1,dims=1)), d3);

julia> chain(ones(Float32, 8, 3))
2×3 Array{Float32,2}:
 -0.5452    -0.5452    -0.5452
 -0.761695  -0.761695  -0.761695

Note that Flux does not require things to be a Chain. Any function will do:

julia> model = function(x)
       x1 = d1(x)
       x2 = d2(x1)
       x3 = cat(x, x2, dims=1)
       return d3(x3)
       end
#53 (generic function with 1 method)

julia> model(ones(Float32, 8, 3))
2×3 Array{Float32,2}:
 -0.5452    -0.5452    -0.5452
 -0.761695  -0.761695  -0.761695

If you want to build and modify models programmatically you can try NaiveNASflux.jl (https://github.com/DrChainsaw/NaiveNASflux.jl), although it might be overkill for what you are trying to achieve.


If I define the model as a generic function as in your last example and then try to attach an optimiser to it with optim = Flux.setup(Flux.Adam(0.01), model), I get a warning that there are no trainable parameters in the model. Flux.params(model) also gives an empty list. How can one train such a model?

With later versions of Flux (after the switch to explicit optimisers) I think the anonymous function approach is a bit cumbersome. You could perhaps create a named tuple with the same structure as the gradient (e.g. (;d1=d1, d2=d2, d3=d3)) and use that instead of model when interacting with optimisers, but I’m not sure it will work out all the way.
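For what it's worth, an untested sketch of that idea (the names nt and forward are mine, just for illustration, and sum stands in for a real loss function):

julia> nt = (;d1=d1, d2=d2, d3=d3);  # named tuple holding the layers, used in place of the model

julia> forward(nt, x) = nt.d3(cat(x, nt.d2(nt.d1(x)), dims=1));

julia> opt = Flux.setup(Flux.Adam(0.01), nt);

julia> grads = Flux.gradient(nt -> sum(forward(nt, ones(Float32, 8, 3))), nt);

julia> Flux.update!(opt, nt, grads[1]);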

Afaik the canonical way is to just create a callable struct and make it a functor, something like this (sorry for completely untested code):

struct MyModel{D1,D2,D3}
    d1::D1
    d2::D2
    d3::D3
end

@functor MyModel

function (m::MyModel)(x)
    x1 = m.d1(x)
    x2 = m.d2(x1)
    x3 = cat(x, x2, dims=1)
    return m.d3(x3)
end
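Assuming that works, the optimiser setup from your question should then find the parameters through the struct's fields (again untested; the layer sizes just mirror the earlier examples):

julia> model = MyModel(Dense(8, 8, relu), Dense(8, 8, relu), Dense(16, 2));

julia> optim = Flux.setup(Flux.Adam(0.01), model);  # no warning this time: the Dense fields are trainable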

It didn’t exist in 2020 when the OP was made, but this use case is exactly what Parallel was designed for.
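For reference, a minimal sketch using the layers from above (Parallel applies each branch to the input and merges the results with the given connection; vcat here is the same as cat(x, x2, dims=1)):

julia> chain = Chain(Parallel(vcat, identity, Chain(d1, d2)), d3);  # identity is the skip branch

julia> chain(ones(Float32, 8, 3));  # same computation as the SkipConnection version above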
