Best way to implement shortcut connections for feed-forward neural networks in Flux.jl

Using Flux.jl, I would like to build a feed-forward neural network, but with shortcut connections like those found in ResNet. What would be the best way to do this?
I know that something like Chain(Dense(10, 5, relu), Dense(5, 4, relu), Dense(4, 2, relu)) would give me a plain feed-forward NN, but how can I incorporate shortcut connections into it?


I’m more familiar with Knet, but I’ll give this a shot. If you look at the docs (http://fluxml.ai/Flux.jl/stable/models/basics.html), you will see that you can define your own prediction and loss functions like so:

W = param(rand(2, 5))
b = param(rand(2))
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

You can use vcat in Knet and it looks like you can use it in Flux too. Note that I have not verified whether the gradients are correct when using vcat in Flux.

W1 = param(rand(2, 5))
b1 = param(rand(2))
W2 = param(rand(2, 7))
b2 = param(rand(2))
predict(x) = W2 * vcat(W1*x .+ b1, x) .+ b2
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)
l = loss(x, y)
Flux.back!(l)
W1.grad
b1.grad
W2.grad
b2.grad

You can think of a layer like layer1 = Dense(5,2,σ) and layer2 = Dense(7,2,σ) as functions and define predict similarly using vcat:

predict(x) = layer2(vcat(layer1(x), x))
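
For completeness, here is that version as a small, self-contained sketch (dimensions chosen to match the example above; again, untested on my end):

using Flux

layer1 = Dense(5, 2, σ)   # 5 inputs -> 2 hidden features
layer2 = Dense(7, 2, σ)   # acts on vcat of hidden (2) + input (5) = 7 features
predict(x) = layer2(vcat(layer1(x), x))

predict(rand(5))          # 2-element output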

I feel like there are two partial questions here. One is about hooking into the Chain/Dense syntax, and the other is about ResNet-style skip connections.

I am going to naively sketch two types of skip connections, which I will call IdentitySkip and CatSkip. The first is based on the later ResNet variation sometimes referred to as the pre-activation version (see [1603.05027] Identity Mappings in Deep Residual Networks). The latter is based on concatenation, similar to how densely connected convolutional networks do it ([1608.06993] Densely Connected Convolutional Networks) and to what @vvjn sketched. My examples use simple feature matrices instead of higher-dimensional arrays, but I hope you get the idea.

(Note that I don’t use Flux much, so take this with a grain of salt. It seems to work, though.)

julia> using Flux

julia> struct IdentitySkip
           inner
       end

julia> struct CatSkip
           inner
       end

julia> (m::IdentitySkip)(x) = m.inner(x) .+ x

julia> (m::CatSkip)(x) = vcat(m.inner(x), x)

julia> m = Chain(Dense(2,3), IdentitySkip(Dense(3, 3)), Dense(3,4))
Chain(Dense(2, 3), IdentitySkip(Dense(3, 3)), Dense(3, 4))

julia> m(rand(2,5))
Tracked 4×5 Array{Float64,2}:
  0.806883   -0.0375264  0.139005   0.441874   0.0202739
 -0.447715    0.549833   0.349582  -0.0181219  0.0610884
 -0.474843    0.299503   0.141969  -0.140966   0.0260037
 -0.0398341   0.400103   0.314375   0.149121   0.0534565
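
The CatSkip version works the same way, except that the output dimension grows: the inner output is stacked on top of the input along the feature dimension, so the following layer has to expect the combined size. A quick sketch of what I mean (assuming the definitions above; the dimensions are just illustrative):

# CatSkip(Dense(3, 3)) produces 3 + 3 = 6 features, so the next Dense expects 6
m2 = Chain(Dense(2, 3), CatSkip(Dense(3, 3)), Dense(6, 4))

size(m2(rand(2, 5)))   # (4, 5)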

One step to add here is to tell Flux where to find the parameters inside these types:

julia> Flux.params(m) |> length
4

julia> Flux.treelike(IdentitySkip)

julia> Flux.params(m) |> length
6
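
The same registration should presumably be needed for CatSkip if you use it in a model:

Flux.treelike(CatSkip)   # likewise exposes the parameters of the wrapped layer(s)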

@Evizero, the IdentitySkip will work for skipping the most recent transformation, i.e. k = 1. Do you think it might be possible to modify this struct to allow skipping over the last k layers, where k > 1?

There shouldn’t be anything special about having just one inner operation. Without testing it myself, I am guessing it should work fine if you write IdentitySkip(Chain(Dense(3,3), Dense(3,3))). Alternatively, add more member variables, or replace inner with a vector/tuple of inner operations. See the sketch below.
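
Something along these lines, for example (same untested caveat as before; it just reuses the IdentitySkip and Chain pieces already shown):

# the identity bypasses the whole two-layer block
block = IdentitySkip(Chain(Dense(3, 3, relu), Dense(3, 3)))

m = Chain(Dense(2, 3), block, Dense(3, 4))
m(rand(2, 5))   # 4×5 output, as before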

Amazing! Thanks a lot @Evizero!

Would it be better to allow for the specification of an activation function?

struct IdentitySkip
   inner
   activation
end

(m::IdentitySkip)(x) = m.activation.(m.inner(x) .+ x)

This is how I understand it from Andrew Ng’s lectures:

policy = Flux.Chain(
  Dense(16*14, 128, relu),
  IdentitySkip(Dense(128, 128), relu),
  Dense(128, 32, relu),
  IdentitySkip(Dense(32, 32), identity),
  Dense(32, 4),
  softmax
  )

You can combine it however you desire. I didn’t include any activation functions simply because the pre-activation formulation used for the identity skip I reference uses a different ordering of things (which would just make everything look more complicated than it needs to be).

Note, though, that the version you propose (and indeed the one Prof. Ng discusses) is from the original ResNet paper, not the identity-skip version from the later revision that I reference; see the paper I linked in my earlier post.
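
To make the difference concrete, here is a rough sketch of the two orderings, ignoring batch normalization and with a single Dense standing in for the residual branch:

inner = Dense(32, 32)   # stand-in for the wrapped transformation

# original ResNet ("post-activation"): the activation is applied after the addition
postact_skip(x) = relu.(inner(x) .+ x)

# pre-activation variant (He et al., 2016): the activation sits inside the branch,
# and the skip path stays a pure identity
preact_skip(x) = inner(relu.(x)) .+ x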


This causes a crash if I use the GPU. See:

using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using CuArrays
# Classify MNIST digits with a convolutional network

imgs = MNIST.images()

labels = onehotbatch(MNIST.labels(), 0:9)

# Partition into batches of size 32
train = [(cat(float.(imgs[i])..., dims = 4), labels[:,i])
         for i in partition(1:60_000, 32)]

train = gpu.(train)

# Prepare test set (first 1,000 images)
#tX = cat(float.(MNIST.images(:test)[1:1000])..., dims = 4) |> gpu
tX = reshape(reduce(hcat, vec.(float.(MNIST.images(:test)))),28,28,1,10_000) |> gpu
tY = onehotbatch(MNIST.labels(:test), 0:9) |> gpu

trainX = reshape(reduce(hcat, vec.(float.(MNIST.images()))),28,28,1,60_000) |> gpu
trainY = onehotbatch(MNIST.labels(), 0:9) |> gpu

struct IdentitySkip
   inner
end

(m::IdentitySkip)(x) = m.inner(x) .+ x

m = Chain(
    Conv((2, 2), 1=>32, relu),
    x -> maxpool(x, (2,2)),
    Conv((2, 2), 32=>32, relu),
    x -> maxpool(x, (2,2)),
    Conv((2, 2), 32=>32, relu),
    x -> reshape(x, :, size(x, 4)),
    Dense(800, 100, relu),
    IdentitySkip(Dense(100, 100, relu)),
    Dense(100, 10),
    softmax) |> gpu

It looks like this was implemented directly in Flux (Model Reference · Flux); can you give that a try?
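
If I remember right, the built-in layer is SkipConnection(layer, connection), where connection combines the layer’s output with its input. Something like this should cover both variants discussed above (untested sketch):

using Flux

# identity-style skip: add the block's output to its input
ident = SkipConnection(Dense(100, 100, relu), +)

# concatenation-style skip along the feature dimension
catskip = SkipConnection(Dense(100, 100, relu), (mx, x) -> vcat(mx, x))

m = Chain(Dense(784, 100, relu), ident, Dense(100, 10), softmax)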
