Don’t think that will work as expected:
- While you can pass functions to `Chain`, they will be opaque, i.e., Flux cannot see inside to get parameters. Further, your function `x -> Dense(5, 5)` never calls the dense layer! Simply use `model1 = Dense(5 => 5, relu)` or `Chain(Dense(5 => 5), relu)` instead.
- You can combine all your model parts into a single model:
```julia
using Flux

model = Chain(Dense(4 => 5, relu),     # your model1
              Parallel(tuple,          # combine both model outputs into a tuple
                       Dense(5 => 6),  # model2
                       Dense(5 => 7))) # model3

# Use as follows ... note that I have changed the dimensions
# to better understand where each value is coming from
batch = rand(Float32, 4, 8)
size.(model(batch))  # will be ((6, 8), (7, 8))

# Placeholder targets so the snippet runs; use your own data here
trueX = rand(Float32, 6, 8)
trueY = rand(Float32, 7, 8)

gradient(model) do m
    m2, m3 = m(batch)
    loss1 = m2 .- trueX
    loss2 = m3 .- trueY
    sum(vcat(loss1, loss2))
end
```
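Going back to the first point, here is a minimal sketch of what the anonymous function actually returns. It assumes a Flux version that supports the `5 => 5` pair syntax, and the names `input`, `broken` and `fixed` are made up for illustration:

```julia
using Flux

input = rand(Float32, 5, 3)   # 5 features, batch of 3

# The anonymous function builds a fresh layer on every call and returns it;
# its argument is never passed through that layer.
broken = Chain(x -> Dense(5 => 5))
y_broken = broken(input)
println(y_broken isa Dense)   # true: the "output" is a layer, not an array

# Putting the layer itself into the model applies it to the input.
fixed = Dense(5 => 5, relu)
y_fixed = fixed(input)
println(size(y_fixed))        # (5, 3)
```

So the `broken` model hands back a new, randomly initialised `Dense` each time instead of transforming the data, which is why there is nothing useful to train in it.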