I would like to train a model that consists of two “submodels”, say model1 and model2. The two models share a common LSTM layer and are followed by two separate dense layers. The final output is the average of the two models.
My (failed) attempt:
using Flux
using Flux.Data: DataLoader
using Flux.Optimise: Optimiser
nfeatures = 11
seq_length = 28
nsamples = 1000
opt = ADAM(0.002)
device = cpu
# Generete toy samples
X = rand(Float32, nfeatures, seq_length, nsamples)
Y = sum(X[1, end-1:end, :], dims=1)
train_loader = DataLoader((X, Y), batchsize=64, shuffle=true)
function apply_model(model, x)
model1 = Chain(model[:mycommomlayer], model[:mydense1])
model2 = Chain(model[:mycommomlayer], model[:mydense2])
Flux.reset!(model1)
y_1 = last(map(model1, [view(x, :, t, :) for t in 1:seq_length]))
Flux.reset!(model2)
y_2 = last(map(model2, [view(x, :, t, :) for t in 1:seq_length]))
return (y_1 + y_2)/2
end
model = Chain(mycommomlayer = LSTM(nfeatures, 32),
mydense1 = Chain(Dropout(0.1), Dense(32, 1)),
mydense2 = Chain(Dropout(0.1), Dense(32, 1)))
function loss(ŷ, y)
l = Flux.mse(y, ŷ)
println("loss is $l")
return l
end
ps = Flux.params(model)
for epoch in 1:5
@info(" start epoch $epoch ")
for (x, y) in train_loader
x, y = x |> device, y |> device
gs = Flux.gradient(ps) do
ŷ = apply_model(model, x)
loss(ŷ, y)
end
Flux.Optimise.update!(opt, ps, gs)
end
end
It crashes with
[ Info: start epoch 1
loss is 1.6848714
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any…) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/operators.jl:655
+(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/ksvfu/src/tangent_arithmetic.jl:122
+(::ChainRulesCore.Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/ksvfu/src/tangent_arithmetic.jl:146
…
Stacktrace:
[1] accum(x::Base.RefValue{Any}, y::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:17
[2] macro expansion
@ ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:57 [inlined]
[3] accum_param(cx::Zygote.Context, x::Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Δ::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:53
[4] back
@ ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:228 [inlined]
[5] #1765#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[6] Pullback
@ ./namedtuple.jl:127 [inlined]
[7] #213
@ ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:203 [inlined]
[8] (::Zygote.var"#1754#back#215"{Zygote.var"#213#214"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, typeof(∂(getindex))}})(Δ::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[9] Pullback
@ ~/.julia/packages/MacroTools/PP9IQ/src/examples/forward.jl:18 [inlined]
[10] (::Zygote.var"#213#214"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#getindex#149))})(Δ::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:203
[11] #1754#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[12] Pullback
@ ~/.julia/packages/MacroTools/PP9IQ/src/examples/forward.jl:18 [inlined]
[13] Pullback
@ ~/.julia/packages/Zygote/cCyLF/src/tools/builtins.jl:15 [inlined]
[14] (::typeof(∂(literal_getindex)))(Δ::NamedTuple{(:cell, :state), Tuple{Nothing, Tuple{Matrix{Float32}, Matrix{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[15] Pullback
@ ~/Library/CloudStorage/Box-Box/MZ/proj/streamflow_ungauged/finetunning_physic/toy.jl:21 [inlined]
[16] (::typeof(∂(apply_model)))(Δ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[17] Pullback
@ ~/Library/CloudStorage/Box-Box/MZ/proj/streamflow_ungauged/finetunning_physic/toy.jl:49 [inlined]
[18] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[19] (::Zygote.var"#94#95"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:357
[20] gradient(f::Function, args::Zygote.Params)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:76
[21] top-level scope
@ ~/Library/CloudStorage/Box-Box/MZ/proj/streamflow_ungauged/finetunning_physic/toy.jl:48
Please, how could I define the two models that share the LSTM layer?
Many thanks in advance!