The right way to implement a residual network (ResNet)

I am not sure if the following implementation of a ResNet is correct. Any suggestions?

What I am not sure about:
(1) In the network nn_x, will the repeated ResNet block sm share parameters? I do not want them to share the same parameters.

using Flux

nodes = 32
nq = 32

m = Chain(Dense(nodes, nodes, mish),
          Dense(nodes, nodes, mish),
          Dense(nodes, nodes, mish))

sm = SkipConnection(m, +);

# TODO: use iterator to construct NN
nn_x = Chain(Dense(n_var_in - 1, nodes, mish),
             sm, sm, sm, sm, sm, sm, sm, sm, sm, sm,
             Dense(nodes, nq * n_var_out)) |> gpu


nn_t = Chain(Dense(1, nodes, mish),
             sm, sm, sm, sm, sm, sm, sm, sm, sm, sm,
             Dense(nodes, nq)) |> gpu

p_all = Flux.params(nn_x, nn_t)

function pred(x)
    # apply nn_x to all but the first input row; reshape to (n_var_out, nq, batch)
    Gu = reshape(nn_x(x[2:end, :]), n_var_out, nq, :)
    # apply nn_t to the first input row
    tq = nn_t(x[1:1, :])
    # per-sample matrix-vector product, concatenated across the batch
    Gutq = hcat([Gu[:, :, i] * tq[:, i] for i in 1:size(x, 2)]...)
    return Gutq
end

They will all share params. Calling Dense(...) is what initializes the parameters, so what happens is that your model will just pass the output from sm into sm again a few times. You might want to put the m = ... and sm = ... lines of your example in a function and call it multiple times to get independent instances, as in the sketch below.
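For example, here is a minimal sketch of that suggestion (the helper name make_block is mine, not from the original post; n_var_in and n_var_out are taken from the post above):

using Flux

# Each call builds a fresh Chain with newly initialized Dense layers,
# wrapped in its own SkipConnection, so blocks do not share weights.
make_block(nodes) = SkipConnection(
    Chain(Dense(nodes, nodes, mish),
          Dense(nodes, nodes, mish),
          Dense(nodes, nodes, mish)),
    +)

nodes = 32
nq = 32

# Ten independent residual blocks instead of reusing one sm ten times.
nn_x = Chain(Dense(n_var_in - 1, nodes, mish),
             [make_block(nodes) for _ in 1:10]...,
             Dense(nodes, nq * n_var_out)) |> gpu

As a sanity check, Flux.params collects each distinct parameter array only once, so nn_x built this way should report far more parameter arrays than the version that reuses a single sm.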

I haven’t scrutinised the rest of the code, but it looks ok to me at a glance.


https://github.com/FluxML/Metalhead.jl/blob/master/src/resnet.jl

You might want to use this directly, or as a reference.
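As a rough sketch of using it directly (assuming a recent Metalhead version; the ResNet(depth) constructor signature is my assumption, so check it against the linked resnet.jl):

using Metalhead

# Assumption: recent Metalhead releases provide a ResNet constructor
# that takes a depth argument such as 18, 34, or 50.
model = ResNet(18)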
