Saving multiple MLJ machines to a single file?

Is it possible to save/load multiple trained MLJ machines to a single file to ease file management? Using JLD2 seems to generate a huge file (40GB using JLD2 compared to 40MB when using MLJ.save).

Thanks for posting this question.

You can serialise however you want, so long as you appropriately preprocess (apply serializable) and post-process (apply restore!) as described here.

So, for bundling everything in one file, you can do something like this:

using MLJ

# should not be needed after https://github.com/alan-turing-institute/MLJ.jl/issues/975
import MLJBase: restore!, serializable


X, y = @load_iris

KNNClassifier = @load KNNClassifier pkg=NearestNeighborModels

machs = map(2:10) do K
    model = KNNClassifier(; K)
    mach = machine(model, X, y) |> fit!
end

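# strip training data and other non-persistent state before saving
serializable_machs = serializable.(machs)

using JLSO
# bundle all machines under a single key in one file
JLSO.save("machines.jlso", :machines => serializable_machs)

loaded_machs = JLSO.load("machines.jlso")[:machines]
# post-process each loaded machine so it can be used for prediction
restore!.(loaded_machs)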

julia> foreach(loaded_machs) do mach
       loss = round(log_loss(predict(mach, X), y) |> mean, sigdigits=3)
       println("K=$(mach.model.K) \t training_loss=$loss")
       end
K=2      training_loss=0.0277
K=3      training_loss=0.0494
K=4      training_loss=0.0604
K=5      training_loss=0.0566
K=6      training_loss=0.0644
K=7      training_loss=0.074
K=8      training_loss=0.072
K=9      training_loss=0.0693
K=10     training_loss=0.0729

Does this address your issue?

Thanks. Do you think serializable should be added to the MLJ cheatsheet?

Mmm. Not sure. I imagine the most common workflow for users is the simplified workflow MLJ.save("my_machine.jls", mach) ... machine("my_machine.jls") which is already in the cheatsheet.
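For reference, here is a minimal sketch of that simplified workflow. It assumes a recent MLJ version and uses the built-in ConstantClassifier purely as a stand-in model; the file name is the one from the example above.

```julia
using MLJ

X, y = @load_iris

# ConstantClassifier ships with MLJ, so no extra model package is needed here
mach = machine(ConstantClassifier(), X, y) |> fit!

# Save a single machine; training data is stripped automatically
MLJ.save("my_machine.jls", mach)

# Load it back; the returned machine is ready for prediction
mach2 = machine("my_machine.jls")
yhat = predict(mach2, X)
```

This covers one machine per file; bundling several machines in one file is where the explicit serializable/restore! steps come in.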

Returning to your earlier comment

Using JLD2 seems to generate a huge file (40GB using JLD2 compared to 40MB when using MLJ.save).

I’m assuming this is because you did not apply serializable, which removes the training data (among other things), before saving with JLD2. If that is not the cause, then this needs investigation. (The JLS-only simplified workflow takes care of this automatically.)

Is there a reliable way to compare (==) machines? I saved using JLSO and then loaded, to compare the machines before saving and after loading. All == comparisons seem to give false, even though fitted_params(mach) looks the same for both, and == on the fitted_params also gives false.

You cannot presently use == for machines to conclude that two machines give the same predictions (transformations, etc). Currently the model API makes no assumption about the meaning of fitresult1 == fitresult2 for the learned parameters fitresult output by fit(::Model, ...), so even if we overloaded == for machines, that probably wouldn’t give you what you want in all cases. If you have a strong use case for introducing a stronger requirement in the API, feel free to raise an issue at MLJModelInterface.jl. But as it would be some work to ensure all model implementations comply with the stricter requirement, that could take some time.
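In the meantime, one workaround is to compare what the machines predict rather than the machines themselves. Here is a rough sketch, assuming a probabilistic classifier: same_predictions is a hypothetical helper (not part of MLJ), the class labels are hard-coded for the iris data, and the stdlib Serialization round trip is only a stand-in for whatever serializer you actually use.

```julia
using MLJ
import MLJBase: restore!, serializable
using Serialization

X, y = @load_iris
mach = machine(ConstantClassifier(), X, y) |> fit!

# Round trip through a file, with the same pre- and post-processing as above
path = tempname()
serialize(path, serializable(mach))
loaded = deserialize(path)
restore!(loaded)

# Hypothetical helper: two machines "agree" if their predicted class
# probabilities match (up to a tolerance) on some test input
function same_predictions(mach1, mach2, Xtest, classes; atol=1e-8)
    p1 = predict(mach1, Xtest)
    p2 = predict(mach2, Xtest)
    return all(isapprox(pdf.(p1, c), pdf.(p2, c); atol) for c in classes)
end

classes = ["setosa", "versicolor", "virginica"]  # labels in the iris data
agree = same_predictions(mach, loaded, X, classes)
```

This only tells you that the two machines agree on the rows you test, which is weaker than equality of the learned parameters, but it is usually what matters after a save/load round trip.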