Save best model in FluxTraining.jl

Cheers, is there a built-in way with fit! to save the model parameters (or model state) after each epoch? What about saving the best model only?

Thanks in advance.

An example based on the documentation:

using Flux, JLD2

x = rand32(10, 100)
y = rand32(1, 100)
m = Chain(Dense(10 => 5, relu), Dense(5 => 2), Dense(2 => 1))
opt_state = Flux.setup(Adam(), m)

# Start the checkpoint with the untrained model and its loss.
jld = jldopen("model-checkpoint.jld2", "w")
jld["loss"] = Flux.Losses.mse(m(x), y)
jld["model_state"] = Flux.state(m)

for epoch in 1:10
    # The loss is computed before the update, so it belongs to the current parameters of m.
    loss, grads = Flux.withgradient(m -> Flux.Losses.mse(m(x), y), m)
    if loss < jld["loss"]
        # An existing key cannot be reassigned directly, so delete it first.
        delete!(jld, "loss")
        jld["loss"] = loss
        delete!(jld, "model_state")
        jld["model_state"] = Flux.state(m)
        @info "Better model found; overwrote the model checkpoint"
    end
    Flux.update!(opt_state, m, grads[1])
end

close(jld)

PS: it would look nicer if this issue gets implemented.

Thanks for the prompt reply, and my apologies for not being clear. I did not mean saving with the Flux package itself, but with the fit! function from FluxTraining.jl.

Thanks again for the prompt help!

Taking EarlyStopping as an example, it seems you need to create struct YourCallBack <: AbstractCallback ... end and extend on(::EpochEnd, phase::Phase, cb::YourCallBack, learner) etc. I would suggest just using Flux.jl directly if you need that degree of flexibility, though.
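
To make the dispatch pattern above more concrete, here is a minimal, self-contained sketch. All types in it (AbstractCallback, EpochEnd, the phases, SaveBest) are stand-ins defined locally for illustration; they are not FluxTraining's own definitions, and the learner is just a NamedTuple with made-up losses. A real callback would have to extend FluxTraining's own on function and follow whatever else its callback interface requires.

```julia
# Stand-in types defined locally for illustration; these are NOT FluxTraining's definitions.
abstract type AbstractCallback end
struct EpochEnd end
struct TrainingPhase end
struct ValidationPhase end

# Fallback: a callback ignores every event/phase pair it has no specific method for.
on(event, phase, cb::AbstractCallback, learner) = nothing

# Hypothetical callback that remembers the model with the lowest validation loss.
mutable struct SaveBest <: AbstractCallback
    best_loss::Float64
    best_model::Any
end
SaveBest() = SaveBest(Inf, nothing)

# Selected by multiple dispatch only at the end of a validation epoch.
function on(::EpochEnd, ::ValidationPhase, cb::SaveBest, learner)
    if learner.loss < cb.best_loss
        cb.best_loss = learner.loss
        cb.best_model = deepcopy(learner.model)
    end
    return nothing
end

# Toy driver: `learner` is a plain NamedTuple carrying a fake model and a made-up loss.
cb = SaveBest()
for (epoch, loss) in enumerate([0.9, 0.4, 0.6])
    learner = (model = [Float64(epoch)], loss = loss)
    on(EpochEnd(), TrainingPhase(), cb, learner)    # hits the fallback, does nothing
    on(EpochEnd(), ValidationPhase(), cb, learner)  # may update the stored best model
end
```

After the loop, cb holds the loss and a copy of the model from the second epoch, since that one had the lowest validation loss.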

Thanks for the feedback. I’ve adopted a solution where all callbacks except logging and metrics are handled explicitly within the loop. This also works around an issue where early stopping currently does not exit gracefully from FluxTraining: https://github.com/FluxML/FluxTraining.jl/issues/159

trainlearner = Learner(model, lossfn;
                      optimizer=opt,
                      callbacks=[log_cb],  # only log callback
)
validlearner = Learner(model, lossfn;
                      callbacks=[metrics, log_cb]  # only log and metrics callbacks
)

for epoch in 1:epochs
    epoch!(trainlearner, TrainingPhase(), trainset)
    epoch!(validlearner, ValidationPhase(), validset)
    # all other callbacks added here, such as save best model, etc
end
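
For the "save best model" step inside such a loop, one possible helper is sketched below. The names BestCheckpoint and save_if_best! are hypothetical, and the sketch uses the Serialization standard library only to stay dependency-free; JLD2, as in the first example of this thread, would work the same way. The losses and the NamedTuple standing in for the model state are made up.

```julia
using Serialization  # standard library

# Hypothetical helper: remembers the best loss seen so far and where to write the checkpoint.
mutable struct BestCheckpoint
    path::String
    best_loss::Float64
end
BestCheckpoint(path::AbstractString) = BestCheckpoint(String(path), Inf)

# Write `state` to disk only if `loss` improves on the best so far; returns true if it saved.
function save_if_best!(ckpt::BestCheckpoint, loss::Real, state)
    loss < ckpt.best_loss || return false
    ckpt.best_loss = loss
    serialize(ckpt.path, state)
    return true
end

# Toy usage with made-up losses and a NamedTuple standing in for the model state.
ckpt = BestCheckpoint(joinpath(mktempdir(), "best-model.jls"))
saved = [save_if_best!(ckpt, loss, (weights = fill(loss, 2),)) for loss in (0.8, 0.5, 0.7)]
restored = deserialize(ckpt.path)
```

In the loop above, this would be called after the validation epoch, for example as save_if_best!(ckpt, valloss, Flux.state(model)), where valloss is whatever validation loss you read out of the learner. Note that files written by Serialization are not guaranteed to be readable across Julia versions, so JLD2 is likely the better choice for long-lived checkpoints.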