Using Flux.jl, I would like to build a feed-forward neural network, but with shortcut connections like those found in ResNet. What would be the best way to do this?
I know that something like Chain(Dense(10, 5, relu), Dense(5, 4, relu), Dense(4, 2, relu)) would give me a feed-forward NN, but how can I incorporate shortcut connections into that?
I’m more familiar with Knet, but I’ll give this a shot. If you look at the docs (http://fluxml.ai/Flux.jl/stable/models/basics.html), you will see that you can define your own prediction and loss functions like so:
W = param(rand(2, 5))
b = param(rand(2))
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)
You can use vcat in Knet and it looks like you can use it in Flux too. Note that I have not verified whether the gradients are correct when using vcat in Flux.
W1 = param(rand(2, 5))
b1 = param(rand(2))
W2 = param(rand(2, 7))
b2 = param(rand(2))
predict(x) = W2 * vcat(W1*x .+ b1, x) .+ b2
loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2)
l = loss(x,y)
Flux.back!(l)
W1.grad
b1.grad
W2.grad
b2.grad
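To actually update the parameters with these gradients, a bare-bones gradient-descent step could look something like this (untested as well; I'm assuming the tracked parameters expose their raw values via .data, analogous to the .grad fields above):

lr = 0.1   # learning rate, picked arbitrarily for illustration
for p in (W1, b1, W2, b2)
    p.data .-= lr .* p.grad   # plain gradient-descent step on the underlying array
    p.grad .= 0               # reset the accumulated gradient before the next back!
end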
You can think of a layer like layer1 = Dense(5,2,σ) and layer2 = Dense(7,2,σ) as functions and define predict similarly using vcat:
predict(x) = layer2(vcat(layer1(x), x))
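Putting that together, the gradients work the same way as before, since the Dense layers carry their own tracked parameters (again, I have not verified this myself; I'm assuming Dense stores its weights and bias in W and b fields):

using Flux

layer1 = Dense(5, 2, σ)   # maps the 5 inputs to 2 features
layer2 = Dense(7, 2, σ)   # acts on the 2 + 5 concatenated features
predict(x) = layer2(vcat(layer1(x), x))
loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2)
l = loss(x, y)
Flux.back!(l)
layer1.W.grad   # gradients accumulate on the layers' own parameters
layer2.W.grad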
I feel like there are two partial questions here. One is about hooking into this Chain, Dense syntax, and the other is about ResNet skip connections.
I am going to naively sketch two types of skip connections, which I will call IdentitySkip and CatSkip. The first is based on the later ResNet variation sometimes referred to as the pre-activation version (see [1603.05027] Identity Mappings in Deep Residual Networks). The latter is based on concatenation, similar to how Dense Conv Nets do it ([1608.06993] Densely Connected Convolutional Networks) and to what @vvjn sketched. My examples use simple feature matrices instead of higher dimensional arrays, but I hope you get the idea.
(Note that I don’t use Flux much, so take this with a grain of salt. It seems to work though)
julia> using Flux
julia> struct IdentitySkip
inner
end
julia> struct CatSkip
inner
end
julia> (m::IdentitySkip)(x) = m.inner(x) .+ x
julia> (m::CatSkip)(x) = vcat(m.inner(x), x)
julia> m = Chain(Dense(2,3), IdentitySkip(Dense(3, 3)), Dense(3,4))
Chain(Dense(2, 3), IdentitySkip(Dense(3, 3)), Dense(3, 4))
julia> m(rand(2,5))
Tracked 4×5 Array{Float64,2}:
0.806883 -0.0375264 0.139005 0.441874 0.0202739
-0.447715 0.549833 0.349582 -0.0181219 0.0610884
-0.474843 0.299503 0.141969 -0.140966 0.0260037
-0.0398341 0.400103 0.314375 0.149121 0.0534565
One step to add here is to tell Flux where to find the parameters inside these types:
julia> Flux.params(m) |> length
4
julia> Flux.treelike(IdentitySkip)
julia> Flux.params(m) |> length
6
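The same goes for CatSkip, except that the concatenation grows the feature dimension, so the next layer has to expect the combined size. A quick sketch (same grain-of-salt caveat as above):

julia> Flux.treelike(CatSkip)

julia> m2 = Chain(Dense(2, 3), CatSkip(Dense(3, 3)), Dense(6, 4));   # 3 + 3 = 6 features reach the last layer

julia> size(m2(rand(2, 5)))
(4, 5)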
@Evizero, the IdentitySkip will work for skipping the most recent transformation, k=1. Do you think it might be possible to modify this struct to allow skipping over the past k layers, where k>1?
There shouldn’t be anything special about just having one inner operation. Without testing it myself, I am guessing it should probably just work fine if you write IdentitySkip(Chain(Dense(3,3), Dense(3,3))). Alternatively, add more member variables or replace it with a vector/tuple of inner operations.
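For example, something along these lines should give you a shortcut over two layers at once (untested, using the IdentitySkip struct from above):

julia> block = IdentitySkip(Chain(Dense(3, 3, relu), Dense(3, 3)));   # the shortcut now bypasses k = 2 layers

julia> m = Chain(Dense(2, 3), block, Dense(3, 4));

julia> size(m(rand(2, 5)))
(4, 5)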
Amazing! Thanks a lot @Evizero!
Would it be better to allow for the specification of an activation function?
struct IdentitySkip
inner
activation
end
(m::IdentitySkip)(x) = m.activation.(m.inner(x) .+ x)
This is how I understand it from Andrew Ng’s lectures:
policy = Flux.Chain(
Dense(16*14, 128, relu),
IdentitySkip(Dense(128, 128), relu),
Dense(128, 32, relu),
IdentitySkip(Dense(32, 32), identity),
Dense(32, 4),
softmax
)
You can combine it however you desire. I didn’t include any activation functions simply because the pre-activation formulation used for the identity skip I reference uses a different ordering of things (which would just make everything look more complicated than it needs to).
Note though that the version you propose (and indeed the one Prof. Ng discusses) is from the original ResNet paper, and not the identity-skip version from the later revision that I reference. See the paper I linked in my earlier post.
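Just to illustrate the difference in ordering, here is a rough sketch of the pre-activation variant (I call it PreActSkip here, the name is made up, and I leave out the batch normalisation the paper uses): the activation is applied inside the residual branch before the inner transformation, and nothing is applied after the addition, so the shortcut stays a pure identity.

struct PreActSkip   # hypothetical name, just for this sketch
    inner
    activation
end

# pre-activation ordering: activation first, then the inner transformation, then a plain identity shortcut
(m::PreActSkip)(x) = m.inner(m.activation.(x)) .+ x

# versus the original-ResNet ordering proposed above:
# (m::IdentitySkip)(x) = m.activation.(m.inner(x) .+ x)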
This crashes if I use the GPU. See:
using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using CuArrays
# Classify MNIST digits with a convolutional network
imgs = MNIST.images()
labels = onehotbatch(MNIST.labels(), 0:9)
# Partition into batches of size 32
train = [(cat(float.(imgs[i])..., dims = 4), labels[:,i])
for i in partition(1:60_000, 32)]
train = gpu.(train)
# Prepare test set (first 1,000 images)
#tX = cat(float.(MNIST.images(:test)[1:1000])..., dims = 4) |> gpu
tX = reshape(reduce(hcat, vec.(float.(MNIST.images(:test)))),28,28,1,10_000) |> gpu
tY = onehotbatch(MNIST.labels(:test), 0:9) |> gpu
trainX = reshape(reduce(hcat, vec.(float.(MNIST.images()))),28,28,1,60_000) |> gpu
trainY = onehotbatch(MNIST.labels(), 0:9) |> gpu
struct IdentitySkip
inner
end
(m::IdentitySkip)(x) = m.inner(x) .+ x
m = Chain(
Conv((2, 2), 1=>32, relu),
x -> maxpool(x, (2,2)),
Conv((2, 2), 32=>32, relu),
x -> maxpool(x, (2,2)),
Conv((2, 2), 32=>32, relu),
x -> reshape(x, :, size(x, 4)),
Dense(800, 100, relu),
IdentitySkip(Dense(100, 100, relu)),
Dense(100, 10),
softmax) |> gpu
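One thing I am not sure about: I never call Flux.treelike(IdentitySkip) here, so maybe |> gpu does not recurse into the struct and the inner Dense(100, 100, relu) stays on the CPU while the rest of the model is moved to the GPU. Just a guess on my part; the definition would then read:

struct IdentitySkip
    inner
end
(m::IdentitySkip)(x) = m.inner(x) .+ x
Flux.treelike(IdentitySkip)   # guess: lets gpu (and params) recurse into `inner`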
It looks like this was implemented directly in Flux (Model Reference · Flux). Can you give that a try?
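If I read the model reference correctly, SkipConnection(layer, connection) computes connection(layer(x), x), so both variants from this thread can be expressed with it. A sketch (not tested against the exact Flux version used above):

using Flux

m = Chain(
    Dense(2, 3),
    SkipConnection(Dense(3, 3), +),      # identity shortcut, like IdentitySkip above
    SkipConnection(Dense(3, 3), vcat),   # concatenation shortcut, like CatSkip above (3 + 3 = 6 features)
    Dense(6, 4))

m(rand(2, 5))   # 4×5 output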