Yeah, I think it makes sense for the clustering to happen in fit. I looked at scikit-learn, and that appears to be what they do. In scikit-learn, predict methods are only provided for the clusterers that can label new points, so KMeans has a predict method and DBSCAN does not.
Also, in the discussion above I was hung up on needing to have two separate types for each model, so that the fit signature would look like this:
fit(m::MyModelParams, X, y) :: MyModel
However, it has finally occurred to me that we can just use one immutable struct where, for an untrained model, the fields that represent fitted parameters are set to nothing (either by making the field parametric or by making it a Union{T, Nothing}). For example, the code for RandomForestRegressor could look like this:
struct RandomForestRegressor
    n_trees::Int
    min_samples_leaf::Int
    trees::Union{Vector{DecisionTree}, Nothing}
    oob_score::Union{Float64, Nothing}
end

function LearnAPI.fit(m::RandomForestRegressor, X, y)
    # ...
    # trees = ...
    # oob_score = ...
    RandomForestRegressor(
        m.n_trees,
        m.min_samples_leaf,
        trees,
        oob_score
    )
end
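The other option mentioned above, making the fitted field parametric instead of a Union, could be sketched roughly like this. Everything here is hypothetical and only for illustration: MeanRegressor, its shrinkage field, and the plain fit and predict functions are stand-ins, chosen so the snippet runs without LearnAPI or a DecisionTree type.

```julia
# Hypothetical toy model that predicts the (scaled) mean of the training targets.
# The type parameter T is Nothing for an untrained model and Float64 once fitted.
struct MeanRegressor{T<:Union{Nothing, Float64}}
    shrinkage::Float64   # hyperparameter (illustrative only)
    mean::T              # fitted parameter
end

# Convenience constructor for an untrained model.
MeanRegressor(; shrinkage=1.0) = MeanRegressor(shrinkage, nothing)

# Stand-in for LearnAPI.fit: returns a new, fitted model and leaves `m` untouched.
function fit(m::MeanRegressor, X, y)
    fitted_mean = m.shrinkage * sum(y) / length(y)
    MeanRegressor(m.shrinkage, fitted_mean)
end

# Dispatch separates trained from untrained models, so no isnothing check is needed.
predict(m::MeanRegressor{Float64}, Xnew) = fill(m.mean, size(Xnew, 1))
predict(::MeanRegressor{Nothing}, Xnew) =
    throw(ArgumentError("MeanRegressor model not fit yet."))

model = MeanRegressor(shrinkage=1.0)
fitted = fit(model, [1.0 2.0; 3.0 4.0], [10.0, 20.0])
prediction = predict(fitted, [5.0 6.0])
```

One trade-off with this variant is that fit returns a different concrete type than it was given (MeanRegressor{Nothing} in, MeanRegressor{Float64} out), whereas the Union version keeps a single concrete type and checks for nothing at runtime.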
Here’s some more code that demonstrates what the clustering models could look like:
Clustering models
module LearnAPI
function fit end
function predict end
function cluster_labels end
end
struct DBSCAN
    eps::Float64
    min_samples::Int
    centers::Union{Vector{Vector{Float64}}, Nothing}
    labels::Union{Vector{Int}, Nothing}
end

function LearnAPI.fit(m::DBSCAN, X)
    # ...
    # centers = ...
    # labels = ...
    DBSCAN(m.eps, m.min_samples, centers, labels)
end

# No predict method for DBSCAN.

function LearnAPI.cluster_labels(m::DBSCAN)
    if isnothing(m.labels)
        throw(ArgumentError("DBSCAN model not fit yet."))
    end
    m.labels
end

# More specific extraction functions are probably not
# defined in LearnAPI.
function cluster_centers(m::DBSCAN)
    if isnothing(m.centers)
        throw(ArgumentError("DBSCAN model not fit yet."))
    end
    m.centers
end
struct KMeans
    n_clusters::Int
    centers::Union{Vector{Vector{Float64}}, Nothing}
    labels::Union{Vector{Int}, Nothing}
end

function LearnAPI.fit(m::KMeans, X)
    # ...
    # centers = ...
    # labels = ...
    KMeans(m.n_clusters, centers, labels)
end

function LearnAPI.predict(m::KMeans, Xnew)
    # Return cluster labels for new data.
end

function LearnAPI.cluster_labels(m::KMeans)
    if isnothing(m.labels)
        throw(ArgumentError("KMeans model not fit yet."))
    end
    m.labels
end

# More specific extraction functions are probably not
# defined in LearnAPI.
function cluster_centers(m::KMeans)
    if isnothing(m.centers)
        throw(ArgumentError("KMeans model not fit yet."))
    end
    m.centers
end
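To make the "predict only where it makes sense" idea concrete, here is a small runnable sketch. It is not the proposed implementation: ToyAPI, ToyKMeans, and ToyDBSCAN are hypothetical, trimmed-down stand-ins for the definitions above, the centers are supplied by hand instead of being fitted, and Xnew is assumed to be a vector of points. The nearest-center rule is just one plausible body for predict.

```julia
# Hypothetical stand-in for LearnAPI so the example runs on its own.
module ToyAPI
function fit end
function predict end
end

struct ToyKMeans
    n_clusters::Int
    centers::Union{Vector{Vector{Float64}}, Nothing}
end

struct ToyDBSCAN
    eps::Float64
    labels::Union{Vector{Int}, Nothing}
end

# Assign each new point to the index of its nearest fitted center.
# Assumes Xnew is a vector of points, each a Vector{Float64}.
function ToyAPI.predict(m::ToyKMeans, Xnew)
    isnothing(m.centers) && throw(ArgumentError("ToyKMeans model not fit yet."))
    sqdist(a, b) = sum(abs2, a .- b)
    [argmin([sqdist(x, c) for c in m.centers]) for x in Xnew]
end

# No predict method is defined for ToyDBSCAN.

# Centers are written out by hand here in place of a real fit.
fitted = ToyKMeans(2, [[0.0, 0.0], [10.0, 10.0]])
labels = ToyAPI.predict(fitted, [[1.0, 1.0], [9.0, 8.0]])

# Generic code can ask whether a model type supports predict.
kmeans_can_predict = hasmethod(ToyAPI.predict, Tuple{ToyKMeans, Any})
dbscan_can_predict = hasmethod(ToyAPI.predict, Tuple{ToyDBSCAN, Any})
```

With this layout, calling predict on a DBSCAN-style model raises a MethodError, and callers that want to branch ahead of time can use hasmethod as shown.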