Hi, I am trying to train a custom model with a list of Flux models as follows
using Flux,Random,Optimisers,Statistics
mutable struct GlobalModel
subnets
function GlobalModel()
model1=Chain(Dense(3=>3,bias=false,init=rand))
model2=Chain(Dense(3=>1,bias=false,init=rand))
subnets=[model1,model2]
new(subnets)
end
end
function call_train(glob::GlobalModel,inputs)
return glob.subnets[2](glob.subnets[1](inputs[1]))
end
(glob::GlobalModel)(inputs) =call_train(glob,inputs)
Flux.@functor GlobalModel
function loss(y_terminal,inputs)
delta = (y_terminal.-sum(inputs[1],dims=1)).^2
return mean(delta)
end
function train!(glob::GlobalModel)
#opt = Optimisers.Adam(0.01)
#opt_state = Optimisers.setup(opt, glob)
optim = Flux.setup(Flux.Adam(0.001), glob)
my_log = []
for epoch in 1:2000
input=[randn((3,5)),randn((2,5))]
val, grads = Flux.withgradient(glob) do m
result = m(input)
loss(result, input)
end
if epoch%500==0
inp=[randn((3,5)),randn((2,5))]
println("Epoch ",epoch, " losses ",loss(glob(inp),inp))
end
Flux.update!(optim, Flux.params(glob), grads)
println(Flux.params(glob))
end
end
glob=GlobalModel();
train!(glob)
However, is seems that the update! function is not updating the parameters in the glob struct, as the parameters remain the same in all epochs. How can I update the glob struct in every step?