Train a Flux struct with a list of models

Hi, I am trying to train a custom model that holds a list of Flux models, as follows:

using Flux,Random,Optimisers,Statistics

mutable struct GlobalModel
    subnets
    function GlobalModel()
        model1=Chain(Dense(3=>3,bias=false,init=rand))
        model2=Chain(Dense(3=>1,bias=false,init=rand))
        subnets=[model1,model2]
        new(subnets)
    end
end

function call_train(glob::GlobalModel,inputs)
    return glob.subnets[2](glob.subnets[1](inputs[1]))
end

(glob::GlobalModel)(inputs) = call_train(glob, inputs)
Flux.@functor GlobalModel

function loss(y_terminal,inputs)
    delta = (y_terminal.-sum(inputs[1],dims=1)).^2
    return mean(delta)
end

function train!(glob::GlobalModel)
    #opt = Optimisers.Adam(0.01)
    #opt_state = Optimisers.setup(opt, glob)
    optim = Flux.setup(Flux.Adam(0.001), glob)
    
    my_log = []
    for epoch in 1:2000
        input=[randn((3,5)),randn((2,5))]
        val, grads = Flux.withgradient(glob) do m
            result = m(input)
            loss(result, input)
        end

        if epoch%500==0
            inp=[randn((3,5)),randn((2,5))]
            println("Epoch ",epoch, " losses ",loss(glob(inp),inp))
        end
        Flux.update!(optim, Flux.params(glob), grads)
        println(Flux.params(glob))
    end
end

glob=GlobalModel();
train!(glob)


However, it seems that the update! function is not updating the parameters in the glob struct: they remain the same in every epoch. How can I update the glob struct at every step?

This ought to be an error, as it should only accept field names of the struct. But since there’s only one field, you just want @functor GlobalModel.

Yes, you’re right, my bad. However, removing it didn’t change anything.

You’re mixing different parameter handling styles here, which is why it doesn’t work. If you use setup and pass a model to (with)gradient, don’t use params (and vice versa). Change this to:

        Flux.update!(optim, glob, grads)

And things should work.

Thanks! I tried that before, but when I changed Flux.params(glob) to just glob, I got the error: type Tuple has no field subnets

This is when you look at grads to make sure it’s what you’d expect. Note that grads should be a tuple (one element for each argument to withgradient), so you need to extract the first element to get at the gradients of glob.
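
To make the point about the tuple concrete, here is a minimal sketch. It uses a small stand-in Chain rather than the GlobalModel from the question, and the names model, x, before and opt_state are illustrative assumptions, not part of the original code. It only shows the shape of what withgradient returns and which part is handed to update!.

```julia
using Flux

# Stand-in model (hypothetical; not the GlobalModel from the question).
model = Chain(Dense(3 => 3; bias=false), Dense(3 => 1; bias=false))
x = randn(Float32, 3, 5)

# Keep a copy of the first layer's weights to compare after the update.
before = copy(model[1].weight)

# withgradient returns the loss value and the gradients. The gradients come
# back as a Tuple with one entry per argument passed to withgradient.
val, grads = Flux.withgradient(model) do m
    sum(abs2, m(x))
end

println(grads isa Tuple)   # true
println(length(grads))     # 1, because only `model` was passed

# grads[1] mirrors the structure of the model, which is what update! expects.
opt_state = Flux.setup(Flux.Adam(0.001), model)
Flux.update!(opt_state, model, grads[1])
```

Passing grads itself instead of grads[1] hands update! a Tuple where it expects something shaped like the model, which is consistent with the "type Tuple has no field subnets" message above.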

Sorry! I also tried it, but the error was the following

MethodError: no method matching GlobalModel(::Vector{Chain{Tuple{Dense{typeof(identity), Matrix{Float64}, Bool}}}})
Closest candidates are:
  GlobalModel() at In[4]:5

Move this inner constructor outside of the definition of GlobalModel. Then instead of new, just call GlobalModel(subnets). The general recommendation is to avoid inner constructors unless you need them, and in this case you don’t :slight_smile:
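
Putting the suggestions from this thread together, the code might look roughly like the sketch below. It keeps the names from the question, moves the zero-argument constructor outside the struct, and passes grads[1] to update!. The epochs keyword is an addition for illustration only. My understanding is that update! rebuilds the struct from its fields, which needs a constructor accepting those fields; treat that explanation as an assumption rather than something stated above.

```julia
using Flux, Statistics

# No inner constructor, so the default GlobalModel(subnets) method exists.
mutable struct GlobalModel
    subnets
end

# Outer convenience constructor that builds the two sub-networks.
function GlobalModel()
    model1 = Chain(Dense(3 => 3, bias=false, init=rand))
    model2 = Chain(Dense(3 => 1, bias=false, init=rand))
    return GlobalModel([model1, model2])
end

(glob::GlobalModel)(inputs) = glob.subnets[2](glob.subnets[1](inputs[1]))
Flux.@functor GlobalModel

function loss(y_terminal, inputs)
    delta = (y_terminal .- sum(inputs[1], dims=1)) .^ 2
    return mean(delta)
end

# `epochs` is a hypothetical keyword added here for convenience.
function train!(glob::GlobalModel; epochs=2000)
    optim = Flux.setup(Flux.Adam(0.001), glob)
    for epoch in 1:epochs
        input = [randn((3, 5)), randn((2, 5))]
        val, grads = Flux.withgradient(glob) do m
            loss(m(input), input)
        end
        # grads is a Tuple; its first entry holds the gradients for glob.
        Flux.update!(optim, glob, grads[1])
        if epoch % 500 == 0
            println("Epoch ", epoch, " loss ", val)
        end
    end
    return glob
end

glob = GlobalModel()
train!(glob)
```

The weights in glob.subnets should now differ from their initial values after training, which is an easy way to check that the update is taking effect.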

You’re my hero :heart_eyes: