Training gives different results when using Flux.train!() inside a function

I am training a NN with some synthetic data. I was running everything at the REPL and things were going in the right direction, but when I wrapped my code inside a function to use custom training loops, my model(x) started to return the same output on every iteration, as if it had stopped updating.

My code looks like:

using Flux;
using Flux.Optimise: update!;
using Flux: normalise;
using Flux: onecold;
using Flux: onehotbatch;
using Flux: @epochs;
using Flux: throttle;

using Random;   # needed for Random.seed!
Random.seed!(125);

ep_max  = 2;  # number of epochs
batch   = 100;   # batch size for training
lr      = 0.001  # learning rate
spt     = 0.01;   # Split ratio: define % to be used as Test data
opt     = ADAM(lr, (0.9, 0.8)); # Optimizer
time_show = 5;
dat_groups = 1:10;
dat_num    = 100;
creating   = false;
reading    = true;

if creating
  data_creator(dat_groups,dat_num); # I create data and store it
end
if reading
  xtrain,ytrain = data_reader(dat_groups); #reads data
end

# batching data
datatrain, datatest = getdata(xtrain',ytrain',spt,batch); # DataLoader function in here                                          

xtr,ytr = recoverdata(datatrain); # recovering training data, to be used if needed
xts,yts = recoverdata(datatest);  # recovering test data, to be used if needed
m = layers(size(xtr,1),size(ytr,1)); # creates layers (6 layers, tanh)
ps = Flux.params(m);                 # initialize parameters

trainmode!(m,true)
evalcb = () -> @show(loss_all(datatrain, m))

for i = 1:ep_max # run ep_max times for a single batch
  println()
  println("**************")
  println(i)
  println()
  Flux.train!(loss, ps, datatrain, opt, cb = throttle(evalcb, time_show));
  println()
  @show accuracy(datatest, m)
end
function accuracy(dataloader, model)
    acc = 0
    for (x,y) in dataloader
        @show model(x)      # **HERE: MODEL ALWAYS THE SAME**
        acc += sum(onecold(cpu(model(x))) .== onecold(cpu(y)))*1 / size(x,2)
    end
    acc/length(dataloader);
end
loss(x,y)   = Flux.mse(m(x),y);

When I run this code, for instance, for 2 epoch iterations, it gives:

julia> include("main.jl")
size of X data is :(1000, 63)
size of Y data is :(1000, 6)

[ Info: Batching data...
[ Info: splitting into 990.0, 10.0

[ Info: Batching train data...
[ Info: Batching test data...
┌ Warning: Number of data points less than batchsize, decreasing the batchsize to 10
└ @ Flux.Data ~/.julia/packages/Flux/Fj3bt/src/data/dataloader.jl:64
[ Info: layers created....

**************
1

loss_all(datatrain, m) = 0.3405524244181338

mod_x = Float32[0.21500134 0.25191692     ........]
acc += sum(onecold(cpu(model(x))) .== onecold(cpu(y)))*1 / size(x,2) = 0.2
accuracy(datatest, m) = 0.2

**************
2

loss_all(datatrain, m) = 0.22800623235686412

mod_x = Float32[0.45594802 0.4247107      ........]
acc += sum(onecold(cpu(model(x))) .== onecold(cpu(y)))*1 / size(x,2) = 0.3
accuracy(datatest, m) = 0.3

In this experiment (at the REPL) everything runs fine, and training converges:

  • cost function loss_all(datatrain, m) changes…
  • accuracy(datatest, m) changes…
  • and model(x) changes, so it is updating…

On the other hand, if I put all the code (main.jl) above inside a function like this:

function all_the_code()
    ....
    ...
    for i = 1:ep_max # run ep_max times for a single batch
      println()
      println("**************")
      println(i)
      println()
      Flux.train!(loss, ps, datatrain, opt, cb = throttle(evalcb, time_show));
      println()
      @show accuracy(datatest, m)
    end
return
end

I get

julia> all_the_code()
size of X data is :(1000, 63)
size of Y data is :(1000, 6)


[ Info: Batching data...
[ Info: splitting into 990.0, 10.0

[ Info: Batching train data...
[ Info: Batching test data...
┌ Warning: Number of data points less than batchsize, decreasing the batchsize to 10
└ @ Flux.Data ~/.julia/packages/Flux/Fj3bt/src/data/dataloader.jl:64
[ Info: layers created....

**************
1

loss_all(datatrain, m) = 0.36721910246630546

mod_x = Float32[-0.23175366 -0.0057259724    ......]
acc += sum(onecold(cpu(model(x))) .== onecold(cpu(y)))*1 / size(x,2) = 0.2
accuracy(datatest, m) = 0.2

**************
2

loss_all(datatrain, m) = 0.3670469086875309

model(x) = Float32[-0.23175366 -0.0057259724    ......] #
acc += sum(onecold(cpu(model(x))) .== onecold(cpu(y)))*1 / size(x,2) = 0.2
accuracy(datatest, m) = 0.2

So, you see, the model is not updating:

  • Cost function loss_all(datatrain, m) always stays the same (approx. 0.36), no matter how many iterations.
  • accuracy(datatest, m) = 0.2 stays the same, on every iteration…
  • model(x) = Float32[-0.23175366 -0.0057259724 ......] # stays the same…

What is going on? Has anybody run into this same issue? Thanks and stay safe!

I would hate to be stuck at the REPL. I always try to avoid global variables, so I need to put this inside a function, but I can’t move on without knowing why this is happening.

My guesses so far:

  • Could it be related to the cpu() function?

  • Note that I am not wrapping only the training loop but the entire script. Could this be the issue?

PS: I am using the same data for both experiments (my seed is the same, so that is OK)…

Any help is welcome. I’ve been stuck here for some time now… :frowning_face:

This probably has to do with your loss(x, y) definition. I suspect the m in its body does not refer to the m created inside all_the_code, but to some other m in global scope. In any case, it is better to pass the model to the loss function explicitly: depending on the scoping of m, you may be using a global variable, which can cause performance issues (and bugs like this one!). Instead you should define

loss(x, y, m) = Flux.mse(m(x), y)

then when you call train!, you can use a closure over m:

Flux.train!((x, y) -> loss(x, y, m), ps, datatrain, opt, cb = throttle(evalcb, time_show))

This will close over the m in the same scope as where Flux.train! was called, so unless you do something really weird, it should be referencing the m you expect.
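To make the scoping issue concrete, here is a minimal sketch in plain Julia, without Flux. The Scale type, loss_global, and the numbers are hypothetical stand-ins, not the code from the question. It assumes a global m is left over from an earlier session and loss was defined at top level. If that is what is happening in your case, ps holds the parameters of the local model while the loss is evaluated with the global one, which would explain why nothing updates.

```julia
# Hypothetical stand-in for a model: a callable that scales its input.
struct Scale
    w::Float64
end
(s::Scale)(x) = s.w * x

# A global model, e.g. one left over from an earlier REPL session.
m = Scale(1.0)

# Defined at top level, so the `m` in the body refers to the global `m`.
loss_global(x, y) = abs2(m(x) - y)

# The model is an explicit argument, so no global lookup is involved.
loss(x, y, m) = abs2(m(x) - y)

function all_the_code()
    m = Scale(3.0)                      # new local `m`; the global is untouched
    closure = (x, y) -> loss(x, y, m)   # closes over the local `m`
    # loss_global still evaluates the global model, not the local one.
    return loss_global(2.0, 6.0), closure(2.0, 6.0)
end

stale, fresh = all_the_code()
println("loss via global m: ", stale)   # uses Scale(1.0)
println("loss via local m:  ", fresh)   # uses Scale(3.0)
```

The two values differ because only the closure sees the model created inside the function. That is the reason for passing m explicitly.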

@darsnack, you nailed it. Thank you very much. Now it runs like it’s supposed to!