If it really has to be a chain, I guess this does what you are after:
```julia
julia> using Flux
julia> d1 = Dense(8, 8, relu);
julia> d2 = Dense(8, 8, relu);
julia> d3 = Dense(16, 2);
julia> chain = Chain(x -> (x, d1(x)),
                     ((x, x1)::Tuple) -> (x, d2(x1)),
                     ((x, x2)::Tuple) -> cat(x, x2, dims=1),
                     d3);
julia> chain(ones(Float32, 8, 3))
2×3 Array{Float32,2}:
 -0.5452    -0.5452    -0.5452
 -0.761695  -0.761695  -0.761695
```
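The tuple trick works because each stage takes and returns a tuple, so the original input travels alongside the intermediate result until the concatenation step. Below is a minimal plain-Julia sketch of the same data flow that runs without Flux. `apply_chain` is a hypothetical stand-in for the way `Chain` applies its layers left to right, and `f1`, `f2`, `f3` are hypothetical deterministic stand-ins for `d1`, `d2`, `d3`.

```julia
# Hypothetical stand-in for Chain: feed each stage the previous stage's output, left to right.
apply_chain(stages, x) = foldl((acc, f) -> f(acc), stages; init = x)

# Hypothetical stand-ins for d1, d2 and d3 (plain functions, not Flux layers).
f1(x) = 2 .* x
f2(x) = x .+ 1
f3(x) = sum(x; dims = 1)

stages = (
    x -> (x, f1(x)),                   # pass the input along next to the first result
    ((x, x1)::Tuple) -> (x, f2(x1)),   # transform only the second element of the tuple
    ((x, x2)::Tuple) -> vcat(x, x2),   # same as cat(x, x2, dims=1)
    f3,
)

# The same computation written as one ordinary function.
direct(x) = f3(vcat(x, f2(f1(x))))

x = ones(Float32, 8, 3)
y = apply_chain(stages, x)
println(size(y))         # (1, 3)
println(y == direct(x))  # true
```

The `::Tuple` annotation is what makes each lambda take a single destructured argument; `((x, x1)) -> ...` without it would parse as a two-argument function.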
Flux also exports SkipConnection for cases like this:
```julia
julia> chain = Chain(SkipConnection(Chain(d1, d2), (x1, x) -> cat(x, x1, dims=1)), d3);
julia> chain(ones(Float32, 8, 3))
2×3 Array{Float32,2}:
 -0.5452    -0.5452    -0.5452
 -0.761695  -0.761695  -0.761695
```
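`SkipConnection` calls its connection function as `connection(layers(x), x)`: the output of the wrapped layers comes first and the untouched input second, which is why the lambda above is written `(x1, x) -> ...`. The sketch below re-implements that call convention in plain Julia to make the argument order explicit. `MySkip`, `f1` and `f2` are hypothetical names, and unlike Flux's real `SkipConnection` this struct is not registered with Flux, so Flux would not find any parameters inside it.

```julia
# Hypothetical, illustration-only version of Flux's SkipConnection.
struct MySkip{L,C}
    layers::L       # the block whose output gets combined with the input
    connection::C   # 2-argument function: (block output, original input)
end

# Same call convention as SkipConnection: connection(layers(x), x).
(s::MySkip)(x) = s.connection(s.layers(x), x)

# Hypothetical stand-ins for d1 and d2.
f1(x) = 2 .* x
f2(x) = x .+ 1

# f2 ∘ f1 applies f1 first, like Chain(d1, d2).
block = MySkip(f2 ∘ f1, (x1, x) -> vcat(x, x1))

x = ones(Float32, 8, 3)
y = block(x)
println(size(y))  # (16, 3): input rows on top, block output below
```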
Note that Flux does not require things to be a Chain. Any function will do:
```julia
julia> model = function(x)
           x1 = d1(x)
           x2 = d2(x1)
           x3 = cat(x, x2, dims=1)
           return d3(x3)
       end
#53 (generic function with 1 method)
julia> model(ones(Float32, 8, 3))
2×3 Array{Float32,2}:
 -0.5452    -0.5452    -0.5452
 -0.761695  -0.761695  -0.761695
```
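The anonymous function works, but it refers to the global `d1`, `d2` and `d3`. A common alternative is a callable struct that keeps the layers together with the forward pass. The sketch below is plain Julia: `ConcatModel` is a hypothetical name, and plain functions stand in for the `Dense` layers so it runs without Flux. To train such a struct with Flux you would additionally have to register it so Flux can find the layers' parameters (`Flux.@functor ConcatModel` in older releases, `Flux.@layer ConcatModel` in newer ones; check which your version provides).

```julia
# Hypothetical callable struct bundling the three layers with the forward pass.
struct ConcatModel{A,B,C}
    d1::A
    d2::B
    d3::C
end

# Same computation as the anonymous `model` function above.
function (m::ConcatModel)(x)
    x1 = m.d1(x)
    x2 = m.d2(x1)
    x3 = cat(x, x2, dims = 1)   # stack the input and the d2 output along dims=1
    return m.d3(x3)
end

# Plain functions standing in for Dense layers so the sketch runs without Flux.
model = ConcatModel(x -> 2 .* x, x -> x .+ 1, x -> sum(x; dims = 1))

y = model(ones(Float32, 8, 3))
println(size(y))  # (1, 3)
```

With the Flux layers from the session above, `ConcatModel(d1, d2, d3)` computes the same forward pass as `model`.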
If you want to build and modify models programmatically, you can try NaiveNASflux.jl (DrChainsaw/NaiveNASflux.jl on GitHub), although it might be overkill for what you are trying to achieve.