Hi,
I’m playing with Flux.jl and results are really nice.
At the moment, I would like to use the RMSprop solver with Early Stopping. Anyone have experience with implementing those in Flux?
Thanks!
OK, I’m going to reply to the second part of my question myself: just use a for loop.
I was not sure that for loops were the way to go, as the Flux docs don’t mention them, but they seem to work well.
For those interested, you can grab the loss for the train and valid sets using the callbacks, and then set a condition for early stopping as a function of the losses. A dummy example of the code (using 5-fold CV with MLDataUtils) looks something like:
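using Flux

record_loss_n_train, record_loss_n_valid = Float64[], Float64[]
fold_select = 1
early_stop = 0
for epoch_idx in 1:nb_epochs
    global fold_select, early_stop # needed when this loop runs at the top level of a script (Julia's soft-scope rule)
    train, valid = folds[fold_select] # selection of the datasets
    evalcb = () -> (push!(record_loss_n_train, loss(train).data),
                    push!(record_loss_n_valid, loss(valid).data))
    Flux.train!(loss, params(m), train, opt, cb = throttle(evalcb, 1))
    fold_select = fold_select % 5 + 1 # cycle through the 5 folds: 1, 2, ..., 5, 1, ...
    # for early stop: the throttled callback may record several losses per epoch,
    # so compare the two most recent entries instead of indexing by epoch_idx
    # (which also fails with index 0 on the first epoch)
    if length(record_loss_n_valid) > 1 && record_loss_n_valid[end] > record_loss_n_valid[end-1]
        early_stop += 1 # counts every increase, not only consecutive ones
    end
    if early_stop > 100
        break
    end
end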
I think the “Flux-iest” way to implement this would be to have your training callback function check for the early-stopping condition, and then have it call Flux.stop() when the condition is met and you want to break out of the training loop.
I’m surprised that this feature isn’t in the documentation! No wonder it wasn’t obvious to you. I plan to put together a PR this weekend to update the docs, unless somebody else beats me to it.
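A minimal sketch of such a callback, written as a closure so that its state (best loss so far, number of evaluations without improvement) lives inside the callback rather than in global variables. The names `make_early_stopping`, `val_loss`, `on_stop` and `patience` are illustrative, not part of Flux; `on_stop` is a parameter so the sketch runs without Flux, and in practice it would be `Flux.stop`.

```julia
# Build a zero-argument callback that tracks the best validation loss seen so far
# and calls `on_stop()` once `patience` consecutive evaluations bring no improvement.
# `val_loss` is any zero-argument function returning the current validation loss.
function make_early_stopping(val_loss, on_stop; patience::Int = 5)
    best = Ref(Inf)    # best validation loss so far (a Ref so the closure can update it)
    bad_evals = Ref(0) # consecutive evaluations without improvement
    return function ()
        current = val_loss()
        if current < best[]
            best[] = current
            bad_evals[] = 0
        else
            bad_evals[] += 1
            if bad_evals[] >= patience
                on_stop()
            end
        end
        return nothing
    end
end
```

In Flux versions whose `Flux.train!` accepts a `cb` keyword, this might be wired up as `cb = make_early_stopping(() -> loss(X_valid, y_valid), Flux.stop; patience = 10)` followed by `Flux.train!(loss, params(m), data, opt, cb = throttle(cb, 10))`, where `X_valid`, `y_valid` and `data` are placeholders. Note that `Flux.stop()` only ends the current `train!` call, so if `train!` is called once per epoch inside an outer loop, that loop needs its own exit condition.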
Indeed, this is very good: it allows early stopping without a loop. However, I must say that using a loop works well and gives a lot of low-level control…
And RMSprop and tons of other solvers are available, again in the source code but not mentioned in the docs.
We should also open pull requests to add RMSprop and the other solvers to the docs.
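For reference, the RMSprop update itself is short. The sketch below is a standalone plain-Julia illustration of the usual rule (a decaying average of squared gradients that rescales each coordinate's step), not Flux's own implementation; `rmsprop_step!` is a hypothetical name and the default hyperparameters are only typical values. With Flux you would instead construct the built-in optimiser, e.g. `opt = RMSProp(0.001)`, and pass it to `Flux.train!`.

```julia
# One RMSprop step for a parameter array `x` with gradient `g`.
# `acc` holds the running average of squared gradients and is updated in place.
function rmsprop_step!(x, acc, g; η = 0.001, ρ = 0.9, ϵ = 1e-8)
    @. acc = ρ * acc + (1 - ρ) * g^2   # decaying average of g²
    @. x -= η * g / (sqrt(acc) + ϵ)    # per-coordinate scaled step
    return x
end

# Usage: minimise f(x) = sum(x .^ 2), whose gradient is 2x.
x = [5.0, -3.0]
acc = zeros(2)
for _ in 1:3000
    rmsprop_step!(x, acc, 2 .* x; η = 0.01)
end
```

Because each coordinate is divided by the root of its own squared-gradient average, the effective step is roughly `η` regardless of the gradient's scale, which is why RMSprop is less sensitive to badly scaled parameters than plain gradient descent.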
Yeah, definitely! I like that about Flux- it gives you a huge amount of flexibility.
Hi,
I would like to create a callback in Flux that does something similar to the EarlyStopping and restore-best-weights callbacks of Keras (saving the best model according to the test data and stopping the optimization if there has been no improvement for a given number of evaluations), but the loss_test_tmp variable isn’t known inside the callback. I don’t understand why the callback function knows my test data X_test and y_test but not loss_test_tmp. I define all of them in the same scope in a script.
function evalcb()
    @show loss_test = loss(X_test, y_test)
    if loss_test < loss_test_tmp
        loss_test_tmp = loss_test
        ct = 0
        BSON.@save joinpath(@__DIR__, "model-checkpoint.bson") m
    else
        ct += 1
    end
    if ct > patience
        println("Optimization will be stopped!")
        Flux.stop()
    end
end
I have done it with a custom training function, but it is much slower.
I suggest you define your own training loop; it won’t be any slower than using Flux.train! with callbacks, and it will probably be much clearer. Here are some examples from the model-zoo:
https://github.com/FluxML/model-zoo/blob/master/vision/lenet_mnist/lenet_mnist.jl
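A minimal sketch of such a loop with Keras-style early stopping and "keep the best model", written without Flux so the control flow is easy to see. `train_with_early_stopping`, `train_epoch!`, `val_loss` and `snapshot` are hypothetical names: `train_epoch!()` would run one pass over the minibatches, `val_loss()` would return the current validation loss, and `snapshot()` would return a copy of the model.

```julia
# Generic training loop with early stopping.
# Returns the snapshot taken at the best validation loss, and that loss.
function train_with_early_stopping(train_epoch!, val_loss, snapshot;
                                   max_epochs::Int = 100, patience::Int = 10)
    best_loss = Inf
    best_state = snapshot()
    bad_epochs = 0   # consecutive epochs without improvement
    for epoch in 1:max_epochs
        train_epoch!()
        current = val_loss()
        if current < best_loss
            # new best: remember the loss and a copy of the model, reset the counter
            best_loss, best_state, bad_epochs = current, snapshot(), 0
        else
            bad_epochs += 1
            if bad_epochs >= patience
                println("Stopping after epoch $epoch: no improvement for $patience epochs")
                break
            end
        end
    end
    return best_state, best_loss
end
```

With a Flux model `m`, `snapshot` could simply be `() -> deepcopy(m)`, so the first returned value is the best model itself; saving it with `BSON.@save` after the loop then plays the role of Keras's checkpoint callback. Since all the state is local to the function, none of the global-variable issues of the callback approach arise.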
Thanks. I will have a look
It would still be nice to know why the callback approach isn’t working; any insights here?
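A likely explanation, assuming evalcb and loss_test_tmp are defined at the top level of a script: Julia treats any variable that is assigned anywhere inside a function body as local to that function unless it is declared `global`. `X_test` and `y_test` are only read, so they resolve to the script's globals; `loss_test_tmp` and `ct` are assigned, so inside `evalcb` they are new, uninitialized locals, and the first comparison throws an `UndefVarError`. The sketch below declares them `global`; the loss values are made up and stand in for `loss(X_test, y_test)`, and `stopped = true` stands in for `Flux.stop()`.

```julia
patience = 3
loss_test_tmp = Inf
ct = 0
stopped = false

# Hypothetical sequence of test losses standing in for loss(X_test, y_test).
test_losses = [0.9, 0.7, 0.75, 0.8, 0.72, 0.71]
evals = 0

function evalcb()
    # Without this declaration, the assignments below would create new local variables.
    global loss_test_tmp, ct, evals, stopped
    evals += 1
    loss_test = test_losses[evals]
    if loss_test < loss_test_tmp
        loss_test_tmp = loss_test
        ct = 0
        # the checkpoint would be saved here, e.g. with BSON.@save
    else
        ct += 1
    end
    if ct > patience
        stopped = true   # in the original callback: Flux.stop()
    end
    return nothing
end
```

Keeping the best loss and the counter inside a closure or a `Ref` instead of globals avoids the declarations entirely and also sidesteps the performance cost of untyped globals.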