Cheers, is there a built-in way with `fit!` to save the model parameters (or model state) after each epoch? What about saving only the best model?
Thanks in advance.
An example based on the Flux documentation:
```julia
using Flux, JLD2

x = rand32(10, 100)
y = rand32(1, 100)
m = Chain(Dense(10 => 5, relu), Dense(5 => 2), Dense(2 => 1))
opt_state = Flux.setup(Adam(), m)

# Open the checkpoint file and store the initial loss and model state as the baseline.
jld = jldopen("model-checkpoint.jld2", "w")
jld["loss"] = Flux.Losses.mse(m(x), y)
jld["model_state"] = Flux.state(m)

for epoch in 1:10
    # `loss` is computed with the parameters *before* this epoch's update,
    # so it matches the state saved below.
    loss, grads = Flux.withgradient(m -> Flux.Losses.mse(m(x), y), m)
    if loss < jld["loss"]
        # JLD2 will not write to a key that already exists, so delete it first.
        delete!(jld, "loss")
        jld["loss"] = loss
        delete!(jld, "model_state")
        jld["model_state"] = Flux.state(m)
        @info "Better model found; overwrote the model checkpoint"
    end
    Flux.update!(opt_state, m, grads[1])
end
close(jld)
```
PS: it would look nicer if this issue gets implemented.
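The snippet above keeps only the best checkpoint. For the other half of the question, saving after every epoch, one option is to write each epoch's state to its own file. This is a minimal sketch under a few assumptions: to stay free of package dependencies it uses Julia's `Serialization` stdlib instead of JLD2, a small NamedTuple stands in for `Flux.state(m)`, and `save_epoch_checkpoint` is a hypothetical helper name. With JLD2 you would write the real state instead, e.g. `jldsave(path; model_state = Flux.state(m))`.

```julia
using Serialization  # stdlib; stands in for JLD2 so the sketch has no package dependencies

# Hypothetical helper: write one checkpoint file per epoch, e.g. "checkpoint-epoch-003.jls".
function save_epoch_checkpoint(dir::AbstractString, epoch::Integer, state)
    path = joinpath(dir, "checkpoint-epoch-" * lpad(epoch, 3, '0') * ".jls")
    serialize(path, state)
    return path
end

# Usage with a stand-in "model state" (a NamedTuple) instead of Flux.state(m).
dir = mktempdir()
paths = String[]
for epoch in 1:3
    state = (weights = fill(Float32(epoch), 2, 2), epoch = epoch)
    push!(paths, save_epoch_checkpoint(dir, epoch, state))
end

# Reload the most recent checkpoint.
restored = deserialize(paths[end])
```

Zero-padding the epoch number keeps the files sorted correctly in a directory listing. Note that the `Serialization` format is not guaranteed to be readable across Julia versions, which is one reason the Flux docs use JLD2 for saved models.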
Thanks for the prompt reply, and apologies for not being clear. I did not mean saving with plain Flux, but with the `fit!` function from FluxTraining.jl.
Thanks again for the prompt help!
Taking `EarlyStopping` as an example, it seems you need to define `struct YourCallBack <: AbstractCallback ... end` and extend `on(::EpochEnd, phase::Phase, cb::YourCallBack, learner)`, etc. If you need that degree of flexibility, though, I suggest just using Flux.jl directly.
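To illustrate what "extend `on` for an event" means, here is a toy sketch of the same multiple-dispatch pattern in plain Julia. `ToyEvent`, `EpochBegin`, `EpochEnd`, `ToyCallback`, `SaveBest` and `run_toy_training` are all defined locally for illustration only; they are not FluxTraining's types or functions, and a real FluxTraining callback has further requirements (such as declaring which learner state it accesses) that this sketch does not model.

```julia
# Toy stand-ins for event and callback types (NOT the real FluxTraining API).
abstract type ToyEvent end
struct EpochBegin <: ToyEvent end
struct EpochEnd <: ToyEvent end

abstract type ToyCallback end

# Fallback: a callback ignores any event it has no specific method for.
on(::ToyEvent, ::ToyCallback, logs) = nothing

# Hypothetical callback that remembers the best validation loss and its epoch.
mutable struct SaveBest <: ToyCallback
    best_loss::Float64
    best_epoch::Int
end
SaveBest() = SaveBest(Inf, 0)

# Only EpochEnd is handled; real code would also write the model state to disk here.
function on(::EpochEnd, cb::SaveBest, logs)
    if logs.valid_loss < cb.best_loss
        cb.best_loss = logs.valid_loss
        cb.best_epoch = logs.epoch
    end
    return nothing
end

# Hypothetical driver: fire each event at every callback, once per epoch.
function run_toy_training(callbacks, valid_losses)
    for (epoch, loss) in enumerate(valid_losses)
        logs = (epoch = epoch, valid_loss = loss)
        foreach(c -> on(EpochBegin(), c, logs), callbacks)
        foreach(c -> on(EpochEnd(), c, logs), callbacks)
    end
end

cb = SaveBest()
run_toy_training([cb], [0.9, 0.5, 0.7, 0.4, 0.6])
```

Because the generic `on(::ToyEvent, ::ToyCallback, logs)` method is a no-op, each callback only implements the events it cares about; Julia picks the more specific `on(::EpochEnd, ::SaveBest, logs)` method when it applies. After the run above, `cb.best_epoch` is 4 and `cb.best_loss` is 0.4.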
Thanks for the feedback. I've adopted a solution where all callbacks except logging and metrics are handled explicitly within the loop. That also works around an issue where early stopping currently does not exit gracefully in FluxTraining: https://github.com/FluxML/FluxTraining.jl/issues/159
```julia
trainlearner = Learner(model, lossfn;
    optimizer = opt,
    callbacks = [log_cb],  # only log callback
)
validlearner = Learner(model, lossfn;
    callbacks = [metrics, log_cb],  # only log and metrics callbacks
)
for epoch in 1:epochs
    epoch!(trainlearner, TrainingPhase(), trainset)
    epoch!(validlearner, ValidationPhase(), validset)
    # all other callbacks added here, such as save best model, etc.
end
```
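The "all other callbacks added here" step can be sketched as follows: keep the best validation loss, overwrite a single best-model checkpoint when it improves, and stop early with a plain `break` once there has been no improvement for `patience` epochs, which leaves the loop normally. This is a minimal sketch under stated assumptions: `train_with_checkpoints`, `patience` and the list of validation losses are hypothetical stand-ins for whatever you read after `epoch!(validlearner, ValidationPhase(), validset)`, a NamedTuple stands in for `Flux.state(model)`, and the `Serialization` stdlib replaces JLD2 to keep the example dependency-free.

```julia
using Serialization  # stdlib; stands in for JLD2 in this sketch

# Hypothetical loop body: `valid_losses` plays the role of the per-epoch validation metric.
function train_with_checkpoints(valid_losses; patience = 2,
                                path = joinpath(mktempdir(), "best.jls"))
    best_loss = Inf
    best_epoch = 0
    stale = 0          # epochs since the last improvement
    last_epoch = 0
    for (epoch, loss) in enumerate(valid_losses)
        last_epoch = epoch
        # (the epoch! training and validation calls would go here)
        model_state = (epoch = epoch,)  # stand-in for Flux.state(model)
        if loss < best_loss
            best_loss, best_epoch, stale = loss, epoch, 0
            serialize(path, model_state)  # overwrite the single "best" checkpoint
        else
            stale += 1
            if stale >= patience
                @info "Early stopping" epoch best_epoch
                break  # ordinary loop exit, no exception involved
            end
        end
    end
    return (; best_loss, best_epoch, last_epoch, path)
end

result = train_with_checkpoints([0.9, 0.6, 0.7, 0.65, 0.5]; patience = 2)
```

With these losses the best epoch is 2 (loss 0.6), and the loop stops after epoch 4 because epochs 3 and 4 did not improve on it, so the final 0.5 is never reached. Since the checkpoint is written only on improvement, the file on disk always holds the best state seen so far.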