Save best model in FluxTraining.jl

Hi, is there a built-in way with `fit!` to save the model parameters (or the model state) after each epoch? And what about saving only the best model?

Here is an example based on the documentation, using plain Flux and JLD2:

```julia
using Flux, JLD2

x = rand32(10, 100)
y = rand32(1, 100)
m = Chain(Dense(10 => 5, relu), Dense(5 => 2), Dense(2 => 1))
opt_state = Flux.setup(Adam(), m)

# Open the checkpoint file and store the initial loss and model state
jld = jldopen("model-checkpoint.jld2", "w")
jld["loss"] = Flux.Losses.mse(m(x), y)
jld["model_state"] = Flux.state(m)

for epoch in 1:10
    loss, grads = Flux.withgradient(model -> Flux.Losses.mse(model(x), y), m)
    # Overwrite the checkpoint only when this epoch's loss beats the stored one
    if loss < jld["loss"]
        delete!(jld, "loss")
        jld["loss"] = loss
        delete!(jld, "model_state")
        jld["model_state"] = Flux.state(m)
        @info "Better model found; overwrote the model checkpoint"
    end
    Flux.update!(opt_state, m, grads[1])
end

close(jld)  # flush and release the checkpoint file
```
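To use such a checkpoint later, the usual Flux pattern is to rebuild the same architecture and copy the saved state into it with `Flux.loadmodel!`. The sketch below is a self-contained round trip; the `build` helper, the file name and the stored loss value are illustrative placeholders, not part of the original example.

```julia
using Flux, JLD2

# Hypothetical helper: must build the same architecture as the saved model
build() = Chain(Dense(10 => 5, relu), Dense(5 => 2), Dense(2 => 1))

path = "model-checkpoint-demo.jld2"   # illustrative file name

# Save: store a (placeholder) loss and the model state, as the training loop does
trained = build()
jldsave(path; loss = 0.5f0, model_state = Flux.state(trained))

# Load: rebuild the architecture, then copy the saved parameters into it
restored = build()
model_state = JLD2.load(path, "model_state")
Flux.loadmodel!(restored, model_state)
```

With the loop above, the same two loading lines would read from `"model-checkpoint.jld2"` instead.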


PS: it would look nicer if this issue gets implemented.

Thanks for the prompt reply, and apologies for not being clear. I did not mean saving with plain Flux, but with the `fit!` function from FluxTraining.jl.

Taking `EarlyStopping` as an example, it seems you would need to define `struct YourCallBack <: AbstractCallback ... end` and extend `on(::EpochEnd, phase::Phase, cb::YourCallBack, learner)` and similar methods. If you need that much flexibility, though, I suggest just using Flux.jl directly.

Thanks for the feedback. I've adopted a solution where every callback except logging and metrics is written explicitly inside the loop. That also solved an issue where early stopping currently does not exit gracefully from FluxTraining:

```julia
using FluxTraining

# model, lossfn, metrics, log_cb, trainset, validset and epochs are defined elsewhere
trainlearner = Learner(model, lossfn;
                       callbacks=[log_cb])           # only the log callback
validlearner = Learner(model, lossfn;
                       callbacks=[metrics, log_cb])  # only the log and metrics callbacks

for epoch in 1:epochs
    epoch!(trainlearner, TrainingPhase(), trainset)
    epoch!(validlearner, ValidationPhase(), validset)
    # all other callbacks go here, e.g. saving the best model
end
```
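The "save best model" and early-stopping steps left as a comment in that loop could look roughly like the sketch below. It is not a FluxTraining API: `validloss`, `BestTracker` and `update_best!` are hypothetical helper names, the validation loss is recomputed with plain Flux over an iterator of `(x, y)` batches (an assumption about what `validset` yields), and the checkpoint is written with JLD2's `jldsave`. Because `update_best!` returns `true` once patience runs out, the loop can simply `break` instead of relying on an exception.

```julia
using Flux, JLD2, Statistics

# Hypothetical helper: mean loss over an iterator of (x, y) batches
validloss(model, lossfn, data) = mean(lossfn(model(x), y) for (x, y) in data)

# Hypothetical tracker for "save best model" plus patience-based early stopping
mutable struct BestTracker
    best::Float64   # best validation loss seen so far
    patience::Int   # epochs without improvement before stopping
    wait::Int       # epochs since the last improvement
end
BestTracker(patience::Int) = BestTracker(Inf, patience, 0)

# Saves the model state when `loss` improves; returns true once training should stop
function update_best!(t::BestTracker, loss, model, path)
    if loss < t.best
        t.best = loss
        t.wait = 0
        jldsave(path; loss = loss, model_state = Flux.state(model))
        @info "Better model found; overwrote the model checkpoint" loss
    else
        t.wait += 1
    end
    return t.wait >= t.patience
end
```

In the loop, after the two `epoch!` calls, one would compute `loss = validloss(model, lossfn, validset)` and write `update_best!(tracker, loss, model, "best-model.jld2") && break`, with `tracker = BestTracker(5)` created before the loop.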
