In MLJ, what does `fit!` do exactly?

A question posted on Slack:

Hello! I’m using MLJ to train a classifier, but I’m not sure I correctly understand how successive trainings “stack”, i.e. what happens when we call `fit!()` several times. (It says “machine has been trained x times”, but does it try to improve the previous fit, like a “gradient descent”?) Can we overfit a model by calling `fit!()` many times? Is there any point in calling `fit!()` several times on different data (like when we are performing k-fold cross-validation)? Some context: in our case, we noticed that classifier accuracies differed (with the same train and test split), and we were trying to train x different classifiers and select the top 25% of these, but EnsembleModel produced a weird result (it often predicts every test sample as belonging to the same class, and gives uninteresting accuracy otherwise).


I believe the Machines section of the manual addresses at least some
of this. But let me reiterate, to hopefully sort out possible points
of confusion. I will return to the poster’s specific questions after
this more general explanation. Suggestions to improve the
documentation are most welcome!

In a normal MLJ workflow one constructs a machine with
`mach = machine(model, X, y)`, and the data `X`, `y` should not be
mutated. However, mutation of `model` (whose fields are the model
“hyperparameters”) is allowed. The outcome of `fit!(mach)` then
depends only on `model` (i.e., on the values of those
hyperparameters). The call `fit!(mach)` is not allowed to mutate
`model`, so if a user does not change `model`, then the outcome
(learned parameters) of a second `fit!` call is the same as the
first. (Indeed, if `model` is unchanged, the lower-level
`MLJModelInterface.fit` function is never even called.)
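
A minimal sketch of this behaviour, assuming a recent MLJ release in which the built-in baseline model `ConstantClassifier` and the `@load_iris` macro are available. It reads the machine's `state` counter (the field the `fit_only!` doc-string quoted at the end refers to, an internal detail that could change between versions) to see whether training actually happened:

```julia
using MLJ

X, y = @load_iris                     # small built-in dataset (150 rows)
model = ConstantClassifier()          # built-in baseline model; no extra package to load
mach = machine(model, X, y)

fit!(mach, verbosity=0)               # never trained before: ab initio training
state_after_first = mach.state        # incremented to 1

fit!(mach, verbosity=0)               # same model, same rows: nothing to do
state_after_second = mach.state       # unchanged, because no training happened

fit!(mach, force=true, verbosity=0)   # force=true retrains from scratch anyway
state_after_forced = mach.state       # incremented again
```

Repeating `fit!` without touching `model` or the rows is therefore cheap and has no effect on the learned parameters.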

If your model is iterative, with an iteration parameter called
`epochs`, say, then increasing the number of epochs in the model means
a new `fit!` call has a different outcome:

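```julia
using MLJ

X, y = @load_iris                                  # any classification data will do
Clf = @load NeuralNetworkClassifier pkg=MLJFlux    # requires MLJFlux to be installed
model = Clf(epochs=10)
mach = machine(model, X, y)
fit!(mach) # train for 10 epochs
model.epochs = 12
fit!(mach) # train for 12 epochs total
```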

However, whether training begins from scratch (“cold restart”) or
simply adds the necessary number of iterations to a partly trained
model (“warm restart”) depends on the implementation. Most packages
implementing MLJ’s model interface for an iterative model implement
the method `MLJModelInterface.update` in addition to
`MLJModelInterface.fit`, as a way of buying into “warm
restart”. (In principle, however, other kinds of hyperparameter
changes may trigger a warm restart as well.)
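
To make the warm-restart mechanism concrete, here is a sketch of a toy iterative model implementing both `MLJModelInterface.fit` and `MLJModelInterface.update`. The model `ToyIterative`, its `epochs` field and the helper `run_epochs` are hypothetical, invented for illustration; the sketch assumes MLJBase and MLJModelInterface are installed. Storing the completed epoch count in `cache` is one possible design, not necessarily what any real package does:

```julia
import MLJModelInterface as MMI
using MLJBase

# Hypothetical iterative model: each "epoch" moves the fitted value halfway
# towards the mean of the target. Features are ignored.
mutable struct ToyIterative <: MMI.Deterministic
    epochs::Int
end
ToyIterative(; epochs=10) = ToyIterative(epochs)

# Hypothetical helper: run `n` epochs starting from `start`.
function run_epochs(start, target, n)
    value = start
    for _ in 1:n
        value += 0.5 * (target - value)
    end
    return value
end

# Cold start: train all `model.epochs` epochs from scratch.
function MMI.fit(model::ToyIterative, verbosity, X, y)
    target = sum(y) / length(y)
    fitresult = run_epochs(0.0, target, model.epochs)
    cache = (epochs_done = model.epochs,)     # remember progress for warm restarts
    report = (epochs_run = model.epochs,)     # epochs executed in this call
    return fitresult, cache, report
end

# Warm restart: only run the epochs that were added since the last fit.
function MMI.update(model::ToyIterative, verbosity, old_fitresult, old_cache, X, y)
    extra = model.epochs - old_cache.epochs_done
    extra < 0 && return MMI.fit(model, verbosity, X, y)   # fewer epochs: cold restart
    target = sum(y) / length(y)
    fitresult = run_epochs(old_fitresult, target, extra)
    cache = (epochs_done = model.epochs,)
    report = (epochs_run = extra,)
    return fitresult, cache, report
end

MMI.predict(::ToyIterative, fitresult, Xnew) = fill(fitresult, size(Xnew, 1))

X = rand(20, 2)
y = rand(20)

model = ToyIterative(epochs=10)
mach = machine(model, X, y)
fit!(mach, verbosity=0)
first_run = report(mach).epochs_run     # 10: trained from scratch
model.epochs = 12
fit!(mach, verbosity=0)                 # model changed, rows unchanged: update is called
second_run = report(mach).epochs_run    # 2: only the extra epochs were run
```

The learned value after the warm restart matches what 12 epochs from scratch would give, which is the “same (or nearly the same)” guarantee described in the `fit_only!` doc-string quoted at the end.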

Another way to change the outcome of `fit!` is to provide the
keyword `rows=...`, which specifies what “view” of the data to train on,
as in `fit!(mach, rows=1:100)`. Whenever the view changes, a `fit!`
call will detect this and make sure to retrain the model.
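
A short sketch of the `rows=...` behaviour, again assuming `ConstantClassifier` and `@load_iris` are available and using the internal `mach.state` counter to detect retraining:

```julia
using MLJ

X, y = @load_iris
mach = machine(ConstantClassifier(), X, y)

fit!(mach, rows=1:100, verbosity=0)     # first training, on rows 1 to 100
s1 = mach.state
fit!(mach, rows=1:100, verbosity=0)     # same view of the data: training skipped
s2 = mach.state
fit!(mach, rows=51:150, verbosity=0)    # new view: retrained from scratch
s3 = mach.state
```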

There is other stuff going on behind the scenes, such as data caching,
but that’s another story.

Returning to the specific questions in the post:

When we call `fit!()` several times (it says “machine has been trained x times”), does it try to improve the previous fitting (like a “gradient descent”)?

No. Not unless you specifically increase the iteration parameter.

Can we overfit a model by calling fit!() many times?

No, not unless you are also changing some hyperparameter (e.g., a regularisation parameter or the iteration parameter).

Is there any point in calling `fit!()` several times on different data (like when we are performing k-fold cross-validation)?

Yes, you call `fit!` on different views of the data using the keyword
`rows=...`, as explained above. Ordinarily, however, there is
boilerplate code for this kind of thing; see the Evaluating Model
Performance section of the manual.
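
For illustration, a hand-rolled 3-fold cross-validation loop that calls `fit!` on different `rows` of the same machine. It assumes `ConstantClassifier`, `predict_mode`, `selectrows` and the `accuracy` measure are available through `using MLJ`; the fold construction is a simple ad hoc choice, not how MLJ's `CV` resampling is implemented:

```julia
using MLJ
using Random

X, y = @load_iris
mach = machine(ConstantClassifier(), X, y)

# Shuffle the row indices (iris is sorted by class) and cut them into 3 folds.
rows = shuffle(MersenneTwister(123), 1:150)
folds = [rows[1:50], rows[51:100], rows[101:150]]

accuracies = Float64[]
for test in folds
    train = setdiff(rows, test)
    fit!(mach, rows=train, verbosity=0)             # new rows, so the machine retrains
    yhat = predict_mode(mach, selectrows(X, test))  # point predictions on the held-out fold
    push!(accuracies, accuracy(yhat, y[test]))
end
```

In practice something like `evaluate!(mach, resampling=CV(nfolds=3, shuffle=true), measure=accuracy, operation=predict_mode)` does this bookkeeping for you.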

More detail

A simplified version of the `fit!` code is presented in the Internals
section of the manual. The `fit!` method calls `fit_only!`, which,
outside of learning networks, is essentially the same thing. The
detailed logic (when to call `MLJModelInterface.fit`, when to call
`MLJModelInterface.update`, and when to skip training altogether) is
spelled out in the `fit_only!` doc-string, which I quote at the end.

I hope this helps.


`MLJBase.fit_only!(mach::Machine; rows=nothing, verbosity=1, force=false)`

Without mutating any other machine on which it may depend, perform one of
the following actions to the machine `mach`, using the data and model
bound to it, and restricting the data to `rows` if specified:

  • Ab initio training. Ignoring any previous learned parameters and
    cache, compute and store new learned parameters. Increment `mach.state`.

  • Training update. Making use of previous learned parameters and/or
    cache, replace or mutate existing learned parameters. The effect is
    the same (or nearly the same) as in ab initio training, but may be
    faster or use less memory, assuming the model supports an update
    option (implements `MLJBase.update`). Increment `mach.state`.

  • No-operation. Leave existing learned parameters untouched. Do not
    increment `mach.state`.

Training action logic

For the action to be a no-operation, either `mach.frozen == true` or
none of the following apply:

  • (i) `mach` has never been trained (`mach.state == 0`).

  • (ii) `force == true`.

  • (iii) The state of some other machine on which `mach` depends has
    changed since the last time `mach` was trained (i.e., the last time
    `mach.state` was incremented).

  • (iv) The specified `rows` have changed since the last retraining and
    `mach.model` does not have `Static` type.

  • (v) `mach.model` has changed since the last retraining.

In any of the cases (i)-(iv), `mach` is trained ab initio. If (v) is
the only condition that applies, then a training update is performed.
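
The decision rules above can be summarised as a small pure-Julia function. This is a simplified model of the logic as described in the doc-string, not MLJBase's actual source; the type `TrainingSituation` and the function `training_action` are hypothetical names:

```julia
# Hypothetical summary of the situation when fit_only! is called.
# Each Bool corresponds to one of the conditions (i)-(v) above.
Base.@kwdef struct TrainingSituation
    frozen::Bool = false
    never_trained::Bool = false      # (i)   mach.state == 0
    force::Bool = false              # (ii)  force == true
    upstream_changed::Bool = false   # (iii) a machine that mach depends on was retrained
    rows_changed::Bool = false       # (iv)  rows changed (assumed false for Static models)
    model_changed::Bool = false      # (v)   mach.model changed
end

function training_action(s::TrainingSituation)
    s.frozen && return :no_op                      # frozen machines are never retrained
    if s.never_trained || s.force || s.upstream_changed || s.rows_changed
        return :ab_initio                          # any of (i)-(iv)
    elseif s.model_changed
        return :update                             # only (v) applies
    else
        return :no_op                              # nothing applies
    end
end

training_action(TrainingSituation(model_changed=true))   # :update
```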

To freeze or unfreeze `mach`, use `freeze!(mach)` or `thaw!(mach)`.
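
A sketch of freezing in action, assuming `freeze!` and `thaw!` are exported by MLJ as the doc-string suggests, and again using the internal `mach.state` counter. Freezing turns even a forced `fit!` into a no-operation:

```julia
using MLJ

X, y = @load_iris
mach = machine(ConstantClassifier(), X, y)
fit!(mach, verbosity=0)
s_trained = mach.state                  # 1 after the first training

freeze!(mach)
fit!(mach, force=true, verbosity=-1)    # frozen: skipped even with force=true (verbosity=-1 silences the warning)
s_frozen = mach.state                   # unchanged

thaw!(mach)
fit!(mach, force=true, verbosity=0)     # thawed: force=true retrains
s_thawed = mach.state                   # incremented again
```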

Implementation detail

The data to which a machine is bound is stored in `mach.args`. Each
element of `args` is either a `Node` object or, in the case that
concrete data was bound to the machine, concrete data wrapped in a
`Source` node. In all cases, to obtain concrete data for actual
training, each argument `N` is called, as in `N()` or `N(rows=rows)`,
and either `MLJBase.fit` (ab initio training) or `MLJBase.update`
(training update) is dispatched on `mach.model` and this data. See the
“Adding models for general use” section of the MLJ documentation for
more on these lower-level training methods.
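
To see this wrapping directly, the sketch below inspects `mach.args` for a machine bound to concrete data. It assumes the `args` field and the callable `Source` nodes behave exactly as the doc-string above describes; these are internal details that may differ between MLJBase versions:

```julia
using MLJ

X, y = @load_iris
mach = machine(ConstantClassifier(), X, y)

Xs, ys = mach.args        # the concrete X and y, each wrapped in a Source node
all_y = ys()              # calling a node with no arguments returns all of its data
first_y = ys(rows=1:5)    # restricting to a view of the data, as fit! does with rows=...
```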
