Best way to implement shortcut connections for feed-forward neural networks in Flux.jl



Using Flux.jl, I would like to build a feed-forward neural network, but with shortcut connections like those found in ResNet. What would be the best way to do this?
I know that something like Chain(Dense(10, 5, relu), Dense(5, 4, relu), Dense(4, 2, relu)) would give me feed-forward NN, but how can I incorporate shortcut connections into that?


I’m more familiar with Knet, but I’ll give this a shot. If you look at the docs, you will see that you can define your own prediction and loss functions like so:

```julia
W = param(rand(2, 5))
b = param(rand(2))
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)
```
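The `param` wrapper above belongs to the older, Tracker-based Flux. As a sketch assuming a newer, Zygote-based Flux release, plain arrays can be used directly and the gradient taken with `Flux.gradient`. The three-argument `predict`/`loss` signatures are my own choice, made so the parameters are passed explicitly:

```julia
using Flux  # assumes a Zygote-based Flux release; no `param` needed

W = rand(2, 5)
b = rand(2)

predict(W, b, x) = W * x .+ b
loss(W, b, x, y) = sum((predict(W, b, x) .- y) .^ 2)

x, y = rand(5), rand(2)

# One gradient per differentiated argument, returned as a tuple.
gW, gb = Flux.gradient((W, b) -> loss(W, b, x, y), W, b)

# For this squared-error loss the gradients have a closed form:
# dL/db = 2(p - y) and dL/dW = 2(p - y) x'.
@assert gb ≈ 2 .* (predict(W, b, x) .- y)
@assert gW ≈ 2 .* (predict(W, b, x) .- y) * x'
```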

You can use vcat in Knet and it looks like you can use it in Flux too. Note that I have not verified whether the gradients are correct when using vcat in Flux.

```julia
W1 = param(rand(2, 5))
b1 = param(rand(2))
W2 = param(rand(2, 7))
b2 = param(rand(2))
predict(x) = W2 * vcat(W1*x .+ b1, x) .+ b2
loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2)
l = loss(x,y)
```
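Regarding the caveat about `vcat` gradients, one way to check them is against central finite differences. This is a sketch assuming a Zygote-based Flux (plain arrays instead of `param`); the helper `fd_grad` is hypothetical, written only for this check:

```julia
using Flux  # assumes a Zygote-based Flux release

W1, b1 = rand(2, 5), rand(2)
W2, b2 = rand(2, 7), rand(2)
x, y = rand(5), rand(2)

# Skip connection: the raw input x is concatenated to the hidden output.
predict(W1, x) = W2 * vcat(W1 * x .+ b1, x) .+ b2
loss(W1) = sum((predict(W1, x) .- y) .^ 2)

# Hypothetical helper: central finite differences, one entry at a time.
function fd_grad(f, A; h = 1e-6)
    g = similar(A)
    for i in eachindex(A)
        Ap = copy(A); Ap[i] += h
        Am = copy(A); Am[i] -= h
        g[i] = (f(Ap) - f(Am)) / (2h)
    end
    return g
end

g_ad = Flux.gradient(loss, W1)[1]   # automatic gradient through vcat
g_fd = fd_grad(loss, W1)            # numerical reference
@assert isapprox(g_ad, g_fd; rtol = 1e-5)
```

Because the loss is quadratic in `W1`, central differences are accurate here up to rounding, so a close match is a reasonable sign that the `vcat` gradient is correct.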

You can treat layers such as layer1 = Dense(5,2,σ) and layer2 = Dense(7,2,σ) as functions and define predict the same way using vcat:

```julia
predict(x) = layer2(vcat(layer1(x), x))
```
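A runnable version of this layer-based variant, assuming Flux 0.13 or later for the `in => out` constructor syntax. In later Flux releases the same wiring can also be written with the built-in `SkipConnection(layer, connection)`, which computes `connection(layer(x), x)`; passing `vcat` as the connection reproduces the concatenation:

```julia
using Flux  # assumes Flux >= 0.13 for the `in => out` Dense syntax

layer1 = Dense(5 => 2, σ)
layer2 = Dense(7 => 2, σ)   # 7 inputs = 2 outputs of layer1 + 5 raw inputs

predict(x) = layer2(vcat(layer1(x), x))

# Same wiring with the built-in layer: SkipConnection(l, f)(x) == f(l(x), x)
model = Chain(SkipConnection(layer1, vcat), layer2)

x = rand(Float32, 5)
@assert predict(x) ≈ model(x)
@assert size(model(x)) == (2,)
```

Both versions share the same `layer1` and `layer2` objects, so they produce identical outputs; the `Chain` version has the advantage that Flux finds all parameters without extra work.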


I feel like there are two partial questions here: one is about hooking into the Chain/Dense syntax, and the other is about ResNet-style skip connections.

I am going to naively sketch two types of skip connections, which I will call IdentitySkip and CatSkip. The first is based on the later ResNet variant sometimes referred to as the pre-activation version. The latter is based on concatenation, similar to how densely connected conv nets (DenseNets) do it and to what @vvjn sketched. My examples use simple feature matrices instead of higher-dimensional arrays, but I hope you get the idea.
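In equation form, writing $F$ for the wrapped inner operation and $x \in \mathbb{R}^d$ for the block input, the two sketches compute

$$
\text{IdentitySkip: } y = F(x) + x, \qquad \text{CatSkip: } y = \begin{bmatrix} F(x) \\ x \end{bmatrix}.
$$

IdentitySkip therefore requires $F(x) \in \mathbb{R}^d$ (same size as the input), while CatSkip with $F(x) \in \mathbb{R}^m$ produces $y \in \mathbb{R}^{m+d}$, so the following layer must accept $m + d$ inputs. This matches @vvjn's example, where a 2-output layer on a 5-dimensional input is followed by Dense(7, 2, σ).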

(Note that I don’t use Flux much, so take this with a grain of salt. It seems to work, though.)

```julia
julia> using Flux

julia> struct IdentitySkip
           inner
       end

julia> struct CatSkip
           inner
       end

julia> (m::IdentitySkip)(x) = m.inner(x) .+ x

julia> (m::CatSkip)(x) = vcat(m.inner(x), x)

julia> m = Chain(Dense(2,3), IdentitySkip(Dense(3, 3)), Dense(3,4))
Chain(Dense(2, 3), IdentitySkip(Dense(3, 3)), Dense(3, 4))

julia> m(rand(2,5))
Tracked 4×5 Array{Float64,2}:
  0.806883   -0.0375264  0.139005   0.441874   0.0202739
 -0.447715    0.549833   0.349582  -0.0181219  0.0610884
 -0.474843    0.299503   0.141969  -0.140966   0.0260037
 -0.0398341   0.400103   0.314375   0.149121   0.0534565
```
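In later Flux releases both patterns are also available through the built-in `SkipConnection(layer, connection)`, which computes `connection(layer(x), x)`. A sketch of the same model, assuming Flux 0.13 or later for the `in => out` syntax:

```julia
using Flux  # assumes Flux >= 0.13

# IdentitySkip equivalent: connection = +, so the block outputs inner(x) + x.
m = Chain(Dense(2 => 3), SkipConnection(Dense(3 => 3), +), Dense(3 => 4))

# CatSkip equivalent: connection = vcat, so the block outputs 3 + 3 = 6 rows.
m_cat = Chain(Dense(2 => 3), SkipConnection(Dense(3 => 3), vcat), Dense(6 => 4))

x = rand(Float32, 2, 5)   # 2 features, batch of 5 samples
@assert size(m(x)) == (4, 5)
@assert size(m_cat(x)) == (4, 5)
```

Since `SkipConnection` is already registered as a Flux layer, its inner parameters are found without the extra step described next.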


One more step is needed: telling Flux where to find the parameters inside these types, so that Flux.params (and therefore training) also picks up the inner layer:

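```julia
julia> Flux.params(m) |> length   # inner Dense(3, 3) not yet visible to Flux

julia> Flux.treelike(IdentitySkip)   # CatSkip needs the same call

julia> Flux.params(m) |> length   # now includes the inner layer's weight and bias
```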
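`Flux.treelike` is from older Flux versions. As a sketch assuming Flux 0.13 or 0.14, where the equivalent is the `Flux.@functor` macro and `Flux.params` is still the usual way to collect parameters, the registration looks like this (the most recent releases provide `Flux.@layer` for the same purpose):

```julia
using Flux  # assumes Flux 0.13 or 0.14 (@functor and Flux.params)

struct IdentitySkip
    inner
end
(m::IdentitySkip)(x) = m.inner(x) .+ x

# Tell Flux to recurse into the `inner` field when collecting parameters,
# moving the model to the GPU, etc.
Flux.@functor IdentitySkip

m = Chain(Dense(2 => 3), IdentitySkip(Dense(3 => 3)), Dense(3 => 4))

# Weight and bias of each of the three Dense layers are now collected.
@assert length(Flux.params(m)) == 6
```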


@Evizero, the IdentitySkip works for skipping over the most recent transformation only (k = 1). Do you think this struct could be modified to skip over the last k layers, where k > 1?


There shouldn’t be anything special about having just one inner operation. Without testing it myself, I expect it to work fine if you write IdentitySkip(Chain(Dense(3,3), Dense(3,3))). Alternatively, add more member variables or replace the field with a vector/tuple of inner operations.
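A sketch of the `Chain`-inside-`IdentitySkip` idea, assuming Flux 0.13 or 0.14 (the struct is redefined so the block is self-contained). Because a `Chain` is itself a callable layer, the shortcut simply spans every layer inside it, here k = 2:

```julia
using Flux  # assumes Flux 0.13 or 0.14

struct IdentitySkip
    inner
end
(m::IdentitySkip)(x) = m.inner(x) .+ x
Flux.@functor IdentitySkip   # so the inner Chain's parameters are trained

# The shortcut jumps over both inner Dense layers (k = 2).
block = IdentitySkip(Chain(Dense(3 => 3, relu), Dense(3 => 3)))
m = Chain(Dense(2 => 3), block, Dense(3 => 4))

x = rand(Float32, 2, 5)
@assert size(m(x)) == (4, 5)

# Same wiring with Flux's built-in SkipConnection:
m2 = Chain(Dense(2 => 3),
           SkipConnection(Chain(Dense(3 => 3, relu), Dense(3 => 3)), +),
           Dense(3 => 4))
@assert size(m2(x)) == (4, 5)
```

The only constraint is the one from the identity skip itself: the last layer in the inner `Chain` must return the same size as the block's input.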


Amazing! Thanks a lot @Evizero!