Problem with model and gradient descent in Flux

Let me add that it is often more convenient to wrap the entire model in a custom struct and define a forward pass yourself, instead of composing it with `Chain` and `Parallel`:

using Flux

"""
    Model

Two-branch model: each branch receives the same input, and the functor
defined on `Model` concatenates the branch outputs along the feature
dimension (dims = 1).
"""
struct Model{B1, B2}
    dense1::B1
    dense2::B2
end

# Register the struct with Flux so its fields are treated as trainable
# parameters by `setup`, `gradient`, and `update!`.
Flux.@layer Model

# Convenience constructor: both branches consume 4 input features; their
# concatenated output therefore has 5 + 6 = 11 features.
Model() = Model(Dense(4 => 5), Dense(4 => 6))

# Forward pass: run the same input through both branches and stack the
# results along the first (feature) dimension.
# `vcat(a, b)` is equivalent to `cat(a, b, dims = 1)` here.
function (m::Model)(x)
    return vcat(m.dense1(x), m.dense2(x))
end

# x = rand(Float32, 4)      # single example (no batch dimension)
# y = rand(Float32, 11)
x = rand(Float32, 4, 5)      # batch of 5 examples, 4 features each
y = rand(Float32, 11, 5)     # targets: 5 + 6 = 11 features per example

# Mean-squared-error loss; the model is an explicit argument so
# `gradient` can differentiate with respect to its parameters.
loss(model, x, y) = Flux.mse(model(x), y)

model = Model()
opt_state = Flux.setup(AdamW(), model)

# Differentiate the loss w.r.t. the model and apply one optimiser step.
grads = gradient(m -> loss(m, x, y), model)[1]
Flux.update!(opt_state, model, grads)
1 Like