[ANN] LearnAPI.jl - Proposal for a basement-level machine learning API

The new proposal with raw_fit and Model{S,O,P} still seems too complicated to me. And it’s a leaky abstraction. Suppose I wrote a random forest package, and I want to implement a trees method that acts on the output of fit and returns the ensemble of fitted decision trees. It would have to look something like this:

trees(m::Model{RandomForestRegressor}) = m.params.trees

In other words, the implementation of trees depends on the Model type having a params field. But we don’t usually make struct fields part of APIs, especially APIs that are meant to be very generic and ecosystem-wide. I suppose you could add a params method to extract params from a Model object (ignoring the fact that you proposed the params function for a different purpose), but it all seems more complicated and less generic than necessary. Anyways, as a package developer I don’t want to have to reach through an external type imposed by LearnAPI in order to access the internals of an object that I’ve implemented myself.
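For contrast, here is a minimal sketch of what the same accessor could look like when the fitted object is a type the package owns. The names `DecisionTree` and `FittedRandomForest` are hypothetical stand-ins for illustration, not part of LearnAPI.jl or any existing package.

```julia
# Hypothetical stand-in for a fitted decision tree.
struct DecisionTree
    depth::Int
end

# The fitted object is a type defined by the random forest package itself,
# so no externally imposed wrapper type is involved.
struct FittedRandomForest
    trees::Vector{DecisionTree}
end

# The accessor only touches fields of the package's own type.
trees(model::FittedRandomForest) = model.trees

forest = FittedRandomForest([DecisionTree(3), DecisionTree(5)])
ntrees = length(trees(forest))
```

Here `trees` depends only on a struct layout that the package author controls, so nothing outside the package needs to know about a `params` field.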

I’m going to write out an interface proposal here, because the whole thing is actually rather short and simple (ignoring traits, target proxies, sci-types, etc).

LearnInterface.jl

module LearnInterface

export fit, predict, minimize

"""
    fit(params, X, y)

Run a learning algorithm on the features `X` and target `y` with the
algorithm hyperparameters specified in `params`. Returns an object
`model` that can make predictions on new features `Xnew` by calling
`predict(model, Xnew)`.
"""
function fit end

"""
    predict(model, X)

Make predictions with `model` on the features `X`. The object
`model` is the output of a call to `fit`.
"""
function predict end

"""
    minimize(model)

Return a minimal version of `model` suitable for serialization.
`minimize` satisfies the following contract:

    predict(model, X) == predict(minimize(model), X)

LearnInterface.jl provides the following default implementation:

    minimize(model) = model
"""
minimize(model) = model

end

So, by default, an implementer of a LearnInterface-compatible model needs to implement two types and two methods:

  • A type for the params object (just the hyperparameters)
  • A type for the model object
  • fit
  • predict
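As a rough illustration of the four items above, here is a sketch of a ridge regressor. The `LearnInterface` module is re-declared inline only so the snippet runs on its own, and `RidgeRegressor` and `RidgeModel` are hypothetical names chosen for this example.

```julia
using LinearAlgebra  # provides the identity scaling `I`

# Inline stand-in for the module proposed above, so the sketch is self-contained.
module LearnInterface
export fit, predict, minimize
function fit end
function predict end
minimize(model) = model
end

import .LearnInterface: fit, predict

# 1. Params type: hyperparameters only (hypothetical example).
struct RidgeRegressor
    lambda::Float64
end

# 2. Model type: the fitted object, owned by the implementing package.
struct RidgeModel
    coefficients::Vector{Float64}
end

# 3. fit: solve the regularized normal equations.
function fit(params::RidgeRegressor, X::AbstractMatrix, y::AbstractVector)
    coefficients = (X' * X + params.lambda * I) \ (X' * y)
    return RidgeModel(coefficients)
end

# 4. predict: apply the learned coefficients to new features.
predict(model::RidgeModel, X::AbstractMatrix) = X * model.coefficients

X = [1.0 0.0; 0.0 1.0; 1.0 1.0]
y = [1.0, 2.0, 3.0]
model = fit(RidgeRegressor(0.0), X, y)
yhat = predict(model, X)
```

Any extra accessor, such as one returning the coefficients, can dispatch directly on `RidgeModel` without going through an interface-level wrapper.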

If the implementer chooses to customize serialization, they will need to implement the following:

  • A type for the output of minimize(model)
  • Another method of predict that dispatches on the output of minimize(model)
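The two extra pieces for custom serialization could look roughly like this. It is a sketch under the assumption that the full model keeps training data that is not needed for prediction. `MeanRegressor`, `MeanModel`, and `MinimalMeanModel` are hypothetical names, and the module is again re-declared inline so the snippet is self-contained.

```julia
# Inline stand-in for the module proposed above.
module LearnInterface
export fit, predict, minimize
function fit end
function predict end
minimize(model) = model
end

import .LearnInterface: fit, predict, minimize

# Params type with no hyperparameters (hypothetical example).
struct MeanRegressor end

# Full model: keeps the training targets around, e.g. for diagnostics.
struct MeanModel
    mean::Float64
    training_targets::Vector{Float64}
end

# Type for the output of minimize(model): only what predict needs.
struct MinimalMeanModel
    mean::Float64
end

fit(::MeanRegressor, X, y) = MeanModel(sum(y) / length(y), collect(Float64, y))

# One predict method per model type; both return the same predictions.
predict(model::MeanModel, X) = fill(model.mean, size(X, 1))
predict(model::MinimalMeanModel, X) = fill(model.mean, size(X, 1))

# Override the default minimize(model) = model for this model type.
minimize(model::MeanModel) = MinimalMeanModel(model.mean)

X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
y = [1.0, 2.0, 6.0]
model = fit(MeanRegressor(), X, y)
slim = minimize(model)
```

With these definitions the contract `predict(model, X) == predict(minimize(model), X)` holds, while `slim` no longer carries the training targets.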

Note that in this interface there is no need for a params method or a report method.

This seems to me like a very simple and intuitive interface, and it is very generic.
