I am not sure whether the following implementation of a ResNet is correct. Any suggestions?
What I am unsure about:
(1) In the network nn_x, will the repeated ResNet block sm share parameters? I do not want the blocks to share the same parameters.
using Flux

nodes = 32   # width of the hidden layers
nq = 32      # output size of nn_t; also appears in the output dimension of nn_x
# One residual block: three Dense layers wrapped in a skip connection
m = Chain(Dense(nodes, nodes, mish),
          Dense(nodes, nodes, mish),
          Dense(nodes, nodes, mish))
sm = SkipConnection(m, +);
# TODO: use an iterator to construct the NN (see the sketch after p_all below)
# n_var_in and n_var_out are problem sizes defined elsewhere.
# Note: the same sm object is listed ten times in each Chain.
nn_x = Chain(Dense(n_var_in - 1, nodes, mish),
             sm, sm, sm, sm, sm, sm, sm, sm, sm, sm,
             Dense(nodes, nq * n_var_out)) |> gpu
nn_t = Chain(Dense(1, nodes, mish),
             sm, sm, sm, sm, sm, sm, sm, sm, sm, sm,
             Dense(nodes, nq)) |> gpu
p_all = Flux.params(nn_x, nn_t)
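
To make (1) concrete, here is a rough way I would check for sharing, plus a sketch of the iterator-based construction I have in mind for the TODO above. The names make_block and nn_x_indep are just placeholders I made up, and the expected counts are my own back-of-the-envelope numbers; I have not verified this is the recommended approach.

# Rough check: Flux.params deduplicates arrays, so if the ten blocks share
# weights the count below should be much smaller than for independent blocks.
length(Flux.params(nn_x))   # 2 + 10*6 + 2 = 64 if the ten blocks are independent

# Sketch of the iterator construction: each call builds fresh layers,
# so (I believe) no parameters are shared between blocks.
make_block() = SkipConnection(Chain(Dense(nodes, nodes, mish),
                                    Dense(nodes, nodes, mish),
                                    Dense(nodes, nodes, mish)), +)
nn_x_indep = Chain(Dense(n_var_in - 1, nodes, mish),
                   [make_block() for _ in 1:10]...,
                   Dense(nodes, nq * n_var_out)) |> gpu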
function pred(x)
    # Rows 2:end of x go through nn_x; reshape to (n_var_out, nq, batch)
    Gu = reshape(nn_x(x[2:end, :]), n_var_out, nq, :)
    # Row 1 of x goes through nn_t; one nq-vector per sample
    tq = nn_t(x[1:1, :])
    # Per-sample matrix-vector product Gu[:, :, i] * tq[:, i], concatenated over the batch
    Gutq = hcat([Gu[:, :, i] * tq[:, i] for i in 1:size(x, 2)]...)
    return Gutq
end
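
For reference, a minimal shape check of what I expect pred to return; the batch size of 5 is arbitrary, and I am assuming n_var_in and n_var_out were defined earlier (say 4 and 3).

x = rand(Float32, n_var_in, 5) |> gpu   # 5 samples: row 1 feeds nn_t, rows 2:end feed nn_x
y = pred(x)
size(y)   # I expect (n_var_out, 5): one n_var_out-vector per column of x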