RMSprop and Early stopping in Flux.jl

Hi,

I’m playing with Flux.jl and results are really nice.

At the moment, I would like to use the RMSprop solver with early stopping. Does anyone have experience implementing those in Flux?

Thanks!

OK, I’m going to reply to the second part of my question myself: just use a for loop.

I was not sure that for loops were the way to go, as nothing is mentioned in the Flux docs about them, but they seem to work well.

For those interested, you can grab the loss for the train and valid sets using the callbacks, and then set a condition for early stopping as a function of the losses. A dummy example of the code (using 5-fold CV with MLDataUtils) looks something like:

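record_loss_n_train = Float64[] # training losses recorded by the callback
record_loss_n_valid = Float64[] # validation losses recorded by the callback
fold_select = 1
early_stop = 0
# Note: on Julia 1.0 and later, when this runs at top level in a script file,
# add `global fold_select, early_stop` inside the loop or wrap it in a function.
for epoch_idx in 1:nb_epochs

    train, valid = folds[fold_select] # selection of the datasets

    evalcb = () -> (push!(record_loss_n_train, loss(train).data),
                    push!(record_loss_n_valid, loss(valid).data))

    Flux.train!(loss, params(m), train, opt, cb = throttle(evalcb, 1))

    fold_select += 1 # for selecting the K-fold between 1 and 5
    if fold_select >= 6
        fold_select = 1
    end

    # for early stop: compare the two most recent validation losses
    # (the throttled callback may record more than one value per epoch,
    # so indexing by epoch_idx is not reliable, and epoch_idx-1 is 0 at first)
    if length(record_loss_n_valid) >= 2 &&
       record_loss_n_valid[end] > record_loss_n_valid[end-1]
        early_stop += 1
    end
    if early_stop > 100
        break
    end
end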
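
For anyone who wants the more common "patience" flavour of early stopping, here is a minimal sketch in plain Julia. It is a variation on the loop above, not the same rule: it tracks the best validation loss so far and resets the counter on every improvement, whereas the loop above counts every increase and never resets. The function name `early_stop_epoch` and its `patience` keyword are made up for illustration and are not part of Flux.

```julia
# Hypothetical helper (not part of Flux): given one validation loss per epoch,
# return the epoch at which patience-based early stopping would trigger.
function early_stop_epoch(valid_losses::AbstractVector{<:Real}; patience::Integer = 3)
    best = Inf    # lowest validation loss seen so far
    stale = 0     # consecutive epochs without improvement
    for (epoch, l) in enumerate(valid_losses)
        if l < best
            best = l
            stale = 0
        else
            stale += 1
        end
        stale >= patience && return epoch
    end
    return length(valid_losses)  # never triggered: all epochs were used
end

losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]
stop_at = early_stop_epoch(losses; patience = 3)
```

In a real training loop the same bookkeeping would sit right after the validation loss is computed, with `break` in place of `return`.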

I think the “Flux-iest” way to implement this would be to have your training callback check for the early stopping condition and call Flux.stop() when the condition is met and you want to break out of the training loop.

I’m surprised that this feature isn’t in the documentation! No wonder it wasn’t obvious to you :) I’ll plan to put together a PR this weekend to update the docs, unless somebody else beats me to it.
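
To make the mechanism concrete, here is a small plain-Julia sketch of how a callback can end a training loop from the inside. As far as I know, Flux.stop() works along the same lines (it throws an exception that the training function catches), but everything below is a toy stand-in: `StopTraining`, `stop_training`, `toy_train!` and `make_callback` are hypothetical names, not Flux's own types or functions, and the losses are fake.

```julia
# Hypothetical stand-ins for illustration only; NOT Flux's API.
struct StopTraining <: Exception end
stop_training() = throw(StopTraining())

# Toy training loop: runs `step!` on each batch, then calls the callback `cb`.
# Returns the number of batches processed before finishing or being stopped.
function toy_train!(step!, data; cb = () -> nothing)
    n = 0
    try
        for batch in data
            step!(batch)
            n += 1
            cb()
        end
    catch err
        # Swallow only our own stop signal; anything else is a real error.
        err isa StopTraining || rethrow()
    end
    return n
end

# Build a callback that stops once the (fake) validation loss has failed to
# improve `patience` times in a row. State lives in the closure, not in globals.
function make_callback(valid_losses; patience = 2)
    best, stale, i = Inf, 0, 0
    return function ()
        i += 1
        l = valid_losses[i]
        if l < best
            best, stale = l, 0
        else
            stale += 1
        end
        stale >= patience && stop_training()
    end
end

fake_losses = [0.9, 0.5, 0.6, 0.7, 0.4, 0.3]
processed = toy_train!(identity, 1:6; cb = make_callback(fake_losses))
```

With real Flux code the callback would compute the validation loss itself and call Flux.stop() instead of `stop_training()`; check the docs for your Flux version, since the callback interface has changed over time.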

Indeed, this is very good. This allows early stopping without a loop. However, I must say that using a loop works well and gives a lot of low-level control…

Also, RMSprop and tons of other solvers are available. Again, they are in the source code but not indicated in the docs.

We should also open pull requests to add RMSprop and the other solvers to the docs.
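
Since RMSprop is not described in the docs, here is a rough plain-Julia sketch of the update rule as I understand it: keep a running average of squared gradients and divide each step by its square root. This is an illustration on a toy quadratic, not Flux's implementation. The function name `rmsprop_minimize` and the hyperparameter values are my own choices, and the exact constructor name and defaults in Flux are best checked in the source for your version.

```julia
# Minimal RMSprop sketch in plain Julia (illustration only, not Flux's code).
# η: learning rate, ρ: decay of the running average, ϵ: numerical safety term.
function rmsprop_minimize(grad, x0; η = 0.01, ρ = 0.9, ϵ = 1e-8, steps = 500)
    x = float.(copy(x0))   # work on a copy so the caller's array is untouched
    acc = zero(x)          # running average of squared gradients
    for _ in 1:steps
        g = grad(x)
        @. acc = ρ * acc + (1 - ρ) * g^2
        @. x -= η * g / (sqrt(acc) + ϵ)
    end
    return x
end

f(x) = sum(abs2, x)     # toy objective with its minimum at zero
grad_f(x) = 2 .* x      # its gradient, written by hand

x_start = [1.0, -2.0]
x_end = rmsprop_minimize(grad_f, x_start)
```

Because each step is normalised by the recent gradient magnitude, the step size stays roughly on the order of η per coordinate, which is why η is usually chosen small.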

However, I must say that using a loop works well and gives a lot of low-level control…

Yeah, definitely! I like that about Flux: it gives you a huge amount of flexibility.

Hi,

I would like to create a callback in Flux that does something similar to the Keras EarlyStopping callback with weight restoration: save the best model according to the test data, and stop the optimization if there has been no improvement for a defined number of evaluations. However, the loss_test_tmp variable isn't known inside the callback. I don't understand why the callback function knows my test data X_test and y_test but not loss_test_tmp. I define all of them in the same scope in a script.

function evalcb()
    @show loss_test = loss(X_test, y_test)
    if loss_test < loss_test_tmp
        loss_test_tmp = loss_test
        ct = 0
        BSON.@save joinpath(@__DIR__, "model-checkpoint.bson") m
    else
        ct += 1
    end
    if ct > patience
        println("Optimization will be stopped!")
        Flux.stop()
    end
end

I have done it with a custom training function, but it is much slower.

I suggest you define your own training loop. It won’t be any slower than using Flux.train! with callbacks, and it will probably be much clearer. Here is an example from the model-zoo:

https://github.com/FluxML/model-zoo/blob/master/vision/lenet_mnist/lenet_mnist.jl

Thanks. I will have a look.

It would still be nice to know why the callback approach isn’t working; any insights here?
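
My best guess is that this is Julia's scoping rule rather than anything in Flux. Inside a function, a variable that is assigned anywhere in the body (like loss_test_tmp and ct in the callback) is treated as a new local variable throughout the function, so reading it before the assignment fails. Globals that are only read (like X_test and y_test) are looked up normally. A minimal sketch with made-up names that reproduces the problem and shows two possible fixes:

```julia
# Globals, as in a script (all names here are hypothetical).
best_loss = Inf
x_data = [1.0, 2.0, 3.0]

# Broken: `best_loss` is assigned in the body, so Julia makes it a local
# variable for the whole function; the read in the `if` hits an unset local.
function cb_broken(l)
    n = length(x_data)    # fine: a global that is only read
    if l < best_loss      # throws UndefVarError
        best_loss = l
    end
    return n
end

# Fix 1: declare the variable as global inside the function.
function cb_global(l)
    global best_loss
    if l < best_loss
        best_loss = l
    end
    return best_loss
end

# Fix 2: keep the state in a closure, so no globals are needed at all.
function make_cb()
    best = Inf
    return function (l)
        if l < best
            best = l
        end
        return best
    end
end

# Confirm that the broken version really fails with UndefVarError.
broken_failed = try
    cb_broken(0.5)
    false
catch err
    err isa UndefVarError
end

cb_global(0.5)            # updates the global best_loss
cb = make_cb()
cb(0.7); cb(0.4)
closure_best = cb(0.9)    # 0.9 is not an improvement, best stays at 0.4
```

In the callback above, that would mean adding `global loss_test_tmp, ct` as the first line of evalcb, or building the callback inside a function so that the counter and best loss are captured locals. Non-constant globals also tend to be slow in Julia, which is one more reason to prefer the closure.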