Re-using layers in Flux.jl: how to train a multi-layer model sharing a common LSTM layer and separate dense layers?

Right now, you are also mapping the dense layer over every element in the sequence. I presume that is not intended? If not, then the loss function should first apply mycommonlayer over x twice to get y_1 and y_2, then call mydense1 and mydense2 on those last outputs.

You can pass mycommonlayer, mydense1, and mydense2 into your loss function (you can toss them into a (named) tuple if you still want model as a single arg). Since both paths reference the same layer object, mycommonlayer will be shared automatically.
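For instance, a minimal sketch of that packaging (the mse loss here is just a stand-in for whatever loss you actually use, and apply_model is defined further down):

model = (rnn = mycommonlayer, dense1 = mydense1, dense2 = mydense2)
loss(m, x, y) = Flux.mse(apply_model(m.rnn, m.dense1, m.dense2, x), y)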

Secondly, I would recommend pre-arranging your input into a vector of views outside the loss. Doing that slicing inside the loss, where the AD can see it, will be prohibitively slow.
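For example, assuming each sample is a features × timesteps matrix (the names and sizes here are made up for illustration):

xs = rand(Float32, 8, 20)                     # 8 features, 20 timesteps
x  = [view(xs, :, t) for t in 1:size(xs, 2)]  # vector of per-timestep views, built once outside the loss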

Putting all this together:

function apply_model(rnn, dense1, dense2, x)
    Zygote.ignore(() -> Flux.reset!(rnn)) # reset the hidden state; we don't need to differentiate this
    y_1 = [rnn(step) for step in x][end]  # run over the whole sequence, keep only the last output

    Zygote.ignore(() -> Flux.reset!(rnn)) # reset again before the second pass
    y_2 = [rnn(step) for step in x][end]

    return (dense1(y_1) .+ dense2(y_2)) ./ 2
end
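With the layers packaged into the named tuple from above, gradients can then be taken against the whole model in one go, e.g. for a single (sequence, target) pair:

grads = gradient(m -> loss(m, x, y), model)[1]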

Alternatively, you could package up the (dense + dense) / 2 part into a single layer:

addscale(y_1, y_2) = (y_1 .+ y_2) ./ 2
post_rnn = Parallel(addscale, mydense1, mydense2)
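When called with multiple inputs, Parallel sends one input to each wrapped layer before combining them with the connection, so this is equivalent to the last line of apply_model:

post_rnn(y_1, y_2) # == addscale(mydense1(y_1), mydense2(y_2))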