Batch size and early stopping in Flux

I'm starting to work with Flux and I have a neural network with 63 inputs, 2 hidden layers and 1 output. Each input and output is a vector of 113344 elements, of which 78801 belong to the training set and the remainder to the test set.

I trained the network for 5000 epochs, which took around 8 hours. To reduce this time, I thought about using early stopping and mini-batches, but I'm not sure how to do it. At the moment, I'm using the following code to train the network:

using Flux

function loss_training(x_train::Array{Float64,2}, y_train::Array{Float64,2})
    model = Chain(Dense(63,63,sigmoid),Dense(63,63,sigmoid),Dense(63,1,sigmoid))
    loss(x,y) = Flux.mae(model(x),y)
    ps = Flux.params(model)
    # A single (x, y) tuple: every epoch is one full-batch update
    dataset = [(x_train',y_train')]
    opt = ADAGrad()
    cb = () -> println(loss(x_train',y_train'))
    Flux.@epochs 5000 Flux.train!(loss,ps,dataset,opt,cb=cb)
    y_hat = model(x_train')'
    return y_hat, model
end

y_hat, model = @time loss_training(x_train, y_train)

y_test_hat = @time model(x_test')'

@Mariana I do not know about early stopping, but for the batch size I recommend using DataLoader (in Flux.Data). With the batchsize and shuffle parameters it is very simple to use:

        data = Flux.Data.DataLoader(X, Y, batchsize=32, shuffle=true)
        opt = ADAM()
        loss(x, y) = Flux.mse(model(x), y)
        ps = Flux.params(model)
        evalcb() = @show loss(X,Y)
        for epoch in 1:8
            println("Epoch $epoch")
            # throttle limits evalcb to at most one call every 3 seconds
            elapsed = @elapsed Flux.train!(loss, ps, data, opt, cb=Flux.throttle(evalcb, 3))
            println("Epoch $epoch: $elapsed secs")
            @show epoch, loss(X,Y)
        end

The code is only an example; it is not complete.
If you have more problems, do not hesitate to ask again.
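
To illustrate what the batch size changes, here is a minimal sketch in plain Julia (no Flux calls). It takes the 78801 training samples from the question and the batchsize=32 from the example, and assumes the final, smaller batch is kept rather than dropped. The variable names are only for illustration.

```julia
# Plain-Julia sketch of how a batch size splits the training samples.
n_train = 78801      # training samples, from the question
batchsize = 32       # value used in the DataLoader example

# Split the sample indices 1:n_train into consecutive chunks of at most `batchsize`.
# Assumption: the last, partial chunk is kept.
batches = collect(Iterators.partition(1:n_train, batchsize))

n_batches = length(batches)              # parameter updates per epoch
last_batch_size = length(last(batches))  # size of the final, partial batch

println("batches per epoch: ", n_batches, ", last batch size: ", last_batch_size)
```

With a dataset made of one (x, y) tuple, as in the question, each epoch performs one parameter update. With mini-batches, each epoch performs one update per batch, so far fewer epochs are usually needed than the original 5000.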

@dmolina, thank you very much for your answer! I tried code similar to yours on my problem and it worked! One question, though: what does cb=throttle(evalcb,3) do in the code?

@Mariana throttle is a function in Flux that lets you run a function, evalcb in this case, at most once every X seconds. So evalcb runs roughly every 3 seconds. This is useful for showing progress without printing at every iteration. See the documentation of the function. There is also Flux.stop(), which can be called from a callback to stop training and could be useful here.
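
Since early stopping was part of the original question, here is a minimal patience-based sketch in plain Julia. It makes no Flux calls. EarlyStopper, should_stop!, first_stop_epoch, patience and min_delta are hypothetical names invented for this illustration, and the loss values are made up to show the logic. In a real run, the loss would typically come from held-out data.

```julia
# Minimal patience-based early-stopping tracker (plain Julia, no Flux calls).
# All names here are hypothetical and only illustrate the idea.
mutable struct EarlyStopper
    patience::Int       # checks to wait without improvement before stopping
    min_delta::Float64  # smallest decrease that counts as an improvement
    best::Float64       # best loss seen so far
    wait::Int           # checks since the last improvement
end

EarlyStopper(; patience=5, min_delta=0.0) = EarlyStopper(patience, min_delta, Inf, 0)

# Record a new loss value; return true when training should stop.
function should_stop!(s::EarlyStopper, current_loss::Real)
    if current_loss < s.best - s.min_delta
        s.best = current_loss
        s.wait = 0
    else
        s.wait += 1
    end
    return s.wait >= s.patience
end

# Return the first epoch at which the stopper fires, or `nothing` if it never does.
function first_stop_epoch(losses; patience=3)
    stopper = EarlyStopper(patience=patience)
    for (epoch, l) in enumerate(losses)
        should_stop!(stopper, l) && return epoch
    end
    return nothing
end

# Made-up per-epoch losses: improvement stalls after the third epoch.
demo_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.6, 0.63]
println(first_stop_epoch(demo_losses; patience=3))
```

In an explicit `for epoch in ...` training loop, the same check can end training with a `break`. Inside a callback, the point where should_stop! returns true would be the place to call the Flux.stop() mentioned above.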

@dmolina, thank you very much!!

Please use backticks ( ` ) around your code in the future to make it easier to read for us. See PSA: make it easier to help you
