Problem with model and gradient descent in Flux

Hi. I have a problem with gradient descent in Flux.
Suppose I have a model like this:

model1 = Chain(x->Dense(5,5),relu)
model2 = Chain(x->model1(x), Dense(5,5), vec)
model3 = Chain(x->model1(x), Dense(5,5), vec)

And I wrap them in params:

parameters = Flux.params(model1,model2,model3)

And I use gradient descent:

gs = gradient(parameters) do
    loss1 = model2(x) - trueX
    loss2 = model3(y) - trueY
    loss = loss1 + loss2
end

My goal is to learn two things. Both are learned from model1, but I use model2 to output one and model3 to output the other. I am curious whether the code above can achieve this.
Thank you!

Don’t think that will work as expected:

  1. While you can pass functions to Chain, they will be opaque, i.e., Flux cannot see inside them to get parameters. Further, your function x -> Dense(5, 5) never calls the dense layer!
    Simply use model1 = Dense(5 => 5, relu) or Chain(Dense(5 => 5), relu) instead (a quick check below this list demonstrates the difference).
  2. You can combine all your model parts into a single model:
    model = Chain(Dense(4 => 5, relu), # your model1
                  Parallel(tuple, # combine both model outputs into tuple
                           Dense(5 => 6),   # model2
                           Dense(5 => 7)))  # model3
    
    # Use as follows ... note that I changed the dimensions to make it clear where each value comes from
    batch = rand(4, 8)
    size.(model(batch))  # will be ((6, 8), (7, 8))
    
    gradient(model) do m
        m2, m3 = m(batch)
        loss1 = m2 .- trueX
        loss2 = m3 .- trueY
        sum(vcat(loss1, loss2))
    end
    
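Here is a quick check of point 1 (a minimal sketch; length(Flux.params(...)) counts the parameter arrays Flux can see):

using Flux

opaque  = Chain(x -> Dense(5, 5)(x))   # a fresh Dense is built inside the closure on every call
visible = Chain(Dense(5 => 5), relu)

length(Flux.params(opaque))    # 0 -- nothing for the optimiser to update
length(Flux.params(visible))   # 2 -- the weight matrix and the bias vector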

Thank you!
I understand now. Your explanation is quite clear!

May I ask another related question?
Suppose I have two separate models. I want to concatenate their outputs and feed the result into a new model. Code like this:

model1 = Chain(Dense(4 => 5),vec)
model2 = Chain(Dense(4 => 6),vec)
model3 = Chain(
    x->cat(model1(x),model2(x),dims=1),
    Dense(11 => 11),
    vec
)

Will this code work?
Thanks for your kind and patient reply!

Not if you want to train model1 and model2 as well. Again, their parameters will not be seen inside the function:

julia> model1 = Chain(Dense(4 => 5),vec)
Chain(
  Dense(4 => 5),                        # 25 parameters
  vec,
) 

julia> model2 = Chain(Dense(4 => 6),vec)
Chain(
  Dense(4 => 6),                        # 30 parameters
  vec,
) 

julia> model3 = Chain(
           x->cat(model1(x),model2(x),dims=1),
           Dense(11 => 11),
           vec
       )
Chain(
  var"#7#8"(),
  Dense(11 => 11),                      # 132 parameters
  vec,
) 

# model3 only has the parameters of the last Dense layer
julia> Dense(11 => 11)
Dense(11 => 11)     # 132 parameters

# Use Parallel again to combine the sub-models -- now the parameters are all visible
julia> model3 = Chain(
           Parallel((x,y)->cat(x,y,dims=1), model1, model2),
           Dense(11 => 11),
           vec
       )
Chain(
  Parallel(
    var"#11#12"(),
    Chain(
      Dense(4 => 5),                    # 25 parameters
      vec,
    ),
    Chain(
      Dense(4 => 6),                    # 30 parameters
      vec,
    ),
  ),
  Dense(11 => 11),                      # 132 parameters
  vec,
)                   # Total: 6 arrays, 187 parameters, 1.035 KiB.
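
A quick way to confirm that the sub-models are now trainable is a single explicit-gradient update step (a minimal sketch; x, y, and the Adam optimiser are placeholders):

using Flux

x = rand(Float32, 4)     # single input vector -- see the PS below about vec and batches
y = rand(Float32, 11)

opt_state = Flux.setup(Adam(), model3)           # tracks all 187 parameters, including model1 and model2
g = gradient(m -> Flux.mse(m(x), y), model3)[1]
Flux.update!(opt_state, model3, g)               # updates every layer, not just the last Dense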

PS: I'm also not sure about the vec at the end of your model. In Flux, models usually work on batches of inputs, i.e.,

julia> m = Dense(4 => 5)
Dense(4 => 5)       # 25 parameters

julia> size(m(rand(4)))  # single input vector
(5,)

julia> size(m(rand(4, 8)))  # batch of 8 inputs
(5, 8)

# model with vec eliminates batch dimension
julia> size(model1(rand(4, 8)))
(40,)
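
If batch support is wanted, one option is simply to drop the vec from the sub-models (a sketch; the Dense(11 => 11) input size is still 5 + 6 = 11):

model3 = Chain(
    Parallel((x, y) -> cat(x, y, dims=1), Dense(4 => 5), Dense(4 => 6)),
    Dense(11 => 11)
)

size(model3(rand(4, 8)))  # (11, 8) -- the batch dimension is preserved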

Let me add that it is often more convenient to wrap the entire model in a custom struct and define a forward pass instead of using Chain and Parallel:

using Flux

struct Model{D1, D2}
    dense1::D1
    dense2::D2
end

Flux.@layer Model

function Model()
    return Model(
        Dense(4 => 5),
        Dense(4 => 6))
end

function (m::Model)(x)
    x1 = m.dense1(x)
    x2 = m.dense2(x)
    return cat(x1, x2, dims=1)
end

# x = rand(Float32, 4) # with no batch dimension
# y = rand(Float32, 11)
x = rand(Float32, 4, 5) # 5 examples in a batch
y = rand(Float32, 11, 5)
loss(model, x, y) = Flux.mse(model(x), y)

model = Model()
opt_state = Flux.setup(AdamW(), model)
g = gradient(model -> loss(model, x, y), model)[1]
Flux.update!(opt_state, model, g)
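
For completeness, the single update step above extends to a training loop in the usual way (a sketch; data is a placeholder for your iterator of (x, y) batches):

for epoch in 1:100
    for (x, y) in data
        g = gradient(m -> loss(m, x, y), model)[1]
        Flux.update!(opt_state, model, g)
    end
end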

Thanks for the detailed instructions!
I think I know how to do this correctly now.
Many thanks for your answer!

Thanks for your kind instructions!
I tried your code and it worked. It's much more logical than my previous version.
These answers are very helpful!
Many thanks for your help!