Let me add that it is often more convenient to wrap the entire model inside a custom struct and define a forward pass, instead of using `Chain` and `Parallel`:
using Flux
# A two-branch model: both dense layers receive the same input, and the
# forward pass concatenates their outputs along the feature dimension.
# Parametric concrete field types keep the struct type-stable.
struct Model{A, B}
    dense1::A
    dense2::B
end

# Register `Model` with Flux so its fields are recursively treated as
# trainable parameters (for `Flux.setup`, `gradient`, `update!`, etc.).
Flux.@layer Model
"""
    Model(; nin = 4, nout1 = 5, nout2 = 6)

Construct a `Model` with two parallel `Dense` branches mapping `nin` input
features to `nout1` and `nout2` output features respectively.

The keyword defaults reproduce the original hard-coded `4 => 5` / `4 => 6`
layout, so the zero-argument call `Model()` behaves exactly as before.
"""
function Model(; nin::Integer = 4, nout1::Integer = 5, nout2::Integer = 6)
    return Model(
        Dense(nin => nout1),
        Dense(nin => nout2))
end
# Forward pass (functor syntax): feed the same input through both branches
# and stack the branch outputs along the first (feature) dimension.
# `vcat` is equivalent to `cat(...; dims = 1)`.
function (m::Model)(x)
    return vcat(m.dense1(x), m.dense2(x))
end
# Dummy data: 4 input features, 11 = 5 + 6 output features (the two branch
# widths stacked), with a batch of 5 examples in the second dimension.
# x = rand(Float32, 4) # with no batch dimension
# y = rand(Float32, 11)
y = rand(Float32, 11, 5)
x = rand(Float32, 4, 5) # 5 examples in a batch
"""
    loss(model, x, y)

Mean-squared-error between the model's prediction for `x` and the target `y`.
"""
function loss(model, x, y)
    return Flux.mse(model(x), y)
end
# Build the model and the optimiser state, then take one training step.
model = Model()
opt_state = Flux.setup(AdamW(), model)

# Differentiate the loss with respect to the model; `gradient` returns a
# tuple with one entry per argument, so take the first.
grads = gradient(m -> loss(m, x, y), model)[1]
Flux.update!(opt_state, model, grads)