Re-using layers in Flux.jl: how to train a multi-layer model sharing a common LSTM layer and separate dense layers?

I would like to train a model that consists of two “submodels”, say model1 and model2. The two submodels share a common LSTM layer and are followed by two separate dense layers. The final output is the average of the two submodels' outputs.

My (failed) attempt:

using Flux
using Flux.Data: DataLoader
using Flux.Optimise: Optimiser 

nfeatures = 11
seq_length = 28
nsamples = 1000
opt = ADAM(0.002)
device = cpu

# Generate toy samples
X = rand(Float32, nfeatures, seq_length, nsamples)
Y = sum(X[1, end-1:end, :], dims=1)

train_loader = DataLoader((X, Y), batchsize=64, shuffle=true)

function apply_model(model, x)

    model1 = Chain(model[:mycommonlayer], model[:mydense1])
    model2 = Chain(model[:mycommonlayer], model[:mydense2])

    Flux.reset!(model1)
    y_1 = last(map(model1, [view(x, :, t, :) for t in 1:seq_length]))

    Flux.reset!(model2)
    y_2 = last(map(model2, [view(x, :, t, :) for t in 1:seq_length]))

    return (y_1 + y_2)/2
end

model = Chain(mycommonlayer = LSTM(nfeatures, 32),
    mydense1 = Chain(Dropout(0.1), Dense(32, 1)),
    mydense2 = Chain(Dropout(0.1), Dense(32, 1)))

function loss(ŷ, y)
    l = Flux.mse(y, ŷ)
    println("loss is $l")
    return l
end

ps = Flux.params(model)

for epoch in 1:5
    @info(" start epoch $epoch ")
    for (x, y) in train_loader
        x, y = x |> device, y |> device
        gs = Flux.gradient(ps) do
            ŷ = apply_model(model, x)
            loss(ŷ, y)
        end
        Flux.Optimise.update!(opt, ps, gs)
    end
end

It crashes with

[ Info: start epoch 1
loss is 1.6848714
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any…) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/operators.jl:655
+(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/ksvfu/src/tangent_arithmetic.jl:122
+(::ChainRulesCore.Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/ksvfu/src/tangent_arithmetic.jl:146

Stacktrace:
[1] accum(x::Base.RefValue{Any}, y::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:17
[2] macro expansion
@ ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:57 [inlined]
[3] accum_param(cx::Zygote.Context, x::Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Δ::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:53
[4] back
@ ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:228 [inlined]
[5] #1765#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[6] Pullback
@ ./namedtuple.jl:127 [inlined]
[7] #213
@ ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:203 [inlined]
[8] (::Zygote.var"#1754#back#215"{Zygote.var"#213#214"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, typeof(∂(getindex))}})(Δ::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[9] Pullback
@ ~/.julia/packages/MacroTools/PP9IQ/src/examples/forward.jl:18 [inlined]
[10] (::Zygote.var"#213#214"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#getindex#149))})(Δ::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:203
[11] #1754#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[12] Pullback
@ ~/.julia/packages/MacroTools/PP9IQ/src/examples/forward.jl:18 [inlined]
[13] Pullback
@ ~/.julia/packages/Zygote/cCyLF/src/tools/builtins.jl:15 [inlined]
[14] (::typeof(∂(literal_getindex)))(Δ::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[15] Pullback
@ ~/Library/CloudStorage/Box-Box/MZ/proj/streamflow_ungauged/finetunning_physic/toy.jl:21 [inlined]
[16] (::typeof(∂(apply_model)))(Δ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[17] Pullback
@ ~/Library/CloudStorage/Box-Box/MZ/proj/streamflow_ungauged/finetunning_physic/toy.jl:49 [inlined]
[18] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[19] (::Zygote.var"#94#95"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:357
[20] gradient(f::Function, args::Zygote.Params)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:76
[21] top-level scope
@ ~/Library/CloudStorage/Box-Box/MZ/proj/streamflow_ungauged/finetunning_physic/toy.jl:48

How can I define the two models so that they share the LSTM layer?
Many thanks in advance!

Right now, you are also mapping the dense layers over every element in the sequence. I presume that is not intended? If not, then the loss function should first apply mycommonlayer over x twice to get y_1 and y_2, and only then call mydense1 and mydense2 on those final outputs.

You can pass mycommonlayer, mydense1 and mydense2 into your loss function (you can toss them into a (named)tuple if you still want model as a single arg). mycommonlayer will be automatically shared.
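For example (a minimal sketch; grouping them in a NamedTuple named model is just one option):

mycommonlayer = LSTM(nfeatures, 32)
mydense1 = Chain(Dropout(0.1), Dense(32, 1))
mydense2 = Chain(Dropout(0.1), Dense(32, 1))

# Optional: group them so you can still pass a single argument around.
# The LSTM's parameters appear only once, so it is shared by construction.
model = (rnn = mycommonlayer, dense1 = mydense1, dense2 = mydense2)
ps = Flux.params(model.rnn, model.dense1, model.dense2)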

Secondly, I would recommend pre-arranging your input into a vector of views. Doing it in the loss where the AD can see it will be prohibitively slow.
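Something along these lines, done once per batch before calling the model (x_seq is just an illustrative name):

x_seq = [view(x, :, t, :) for t in 1:seq_length]  # vector of (nfeatures × batchsize) views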

Putting all this together:

using Zygote  # for Zygote.ignore

function apply_model(rnn, dense1, dense2, x)
    Zygote.ignore(() -> Flux.reset!(rnn)) # reset the hidden state; we don't need to differentiate this
    y_1 = [rnn(step) for step in x][end]

    Zygote.ignore(() -> Flux.reset!(rnn)) # reset again before the second pass
    y_2 = [rnn(step) for step in x][end]

    return (dense1(y_1) .+ dense2(y_2)) ./ 2
end
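
A sketch of how the training loop could then look, assuming the layers are defined individually as above and each batch is pre-sliced into views:

ps = Flux.params(mycommonlayer, mydense1, mydense2)

for epoch in 1:5
    for (x, y) in train_loader
        x_seq = [view(x, :, t, :) for t in 1:seq_length]  # slice outside the gradient call
        gs = Flux.gradient(ps) do
            loss(apply_model(mycommonlayer, mydense1, mydense2, x_seq), y)
        end
        Flux.Optimise.update!(opt, ps, gs)
    end
end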

Alternatively, you could package up the (dense + dense) / 2 part into a single layer:

addscale(y_1, y_2) = (y_1 .+ y_2) ./ 2
post_rnn = Parallel(addscale, mydense1, mydense2)
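
Since the LSTM is shared and reset before each pass, its output is the same for both heads, so apply_model then only has to run it once (a sketch, reusing the x_seq idea from above):

function apply_model(rnn, post_rnn, x_seq)
    Zygote.ignore(() -> Flux.reset!(rnn))
    h = [rnn(step) for step in x_seq][end]  # final LSTM output for the sequence
    return post_rnn(h)                      # == (mydense1(h) .+ mydense2(h)) ./ 2
end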