Yeah, I think it makes sense for the clustering to happen in `fit`. I looked at scikit-learn, and that appears to be what they do. In scikit-learn, `predict` methods are only provided for the clusterers that can label new points, so `KMeans` has a `predict` method and `DBSCAN` does not have a `predict` method.

Also, in the discussion above I was hung up on needing two separate types for each model, so that the `fit` signature would look like this:

```
# Two-type design: `fit` takes a hyperparameter struct and returns a separate fitted-model type.
fit(m::MyModelParams, X, y) :: MyModel
```

However, it has finally occurred to me that we can use just one immutable struct in which, for an untrained model, the fields that represent fitted parameters are set to `nothing` (either by making the field's type parametric or by declaring it as `Union{T, Nothing}`). For example, the code for `RandomForestRegressor` could look like this:

```
"""
    RandomForestRegressor(n_trees, min_samples_leaf)

Random-forest regression model.

`n_trees` and `min_samples_leaf` are hyperparameters. `trees` and `oob_score` hold
fitted state and are `nothing` for an untrained model. The two-argument constructor
builds such an untrained model; `LearnAPI.fit` returns a new instance with them filled in.
"""
struct RandomForestRegressor
    n_trees::Int
    min_samples_leaf::Int
    trees::Union{Vector{DecisionTree}, Nothing}  # fitted state; `nothing` when untrained
    oob_score::Union{Float64, Nothing}           # fitted state; `nothing` when untrained
end

# Untrained model: only hyperparameters are given, every fitted field starts as `nothing`.
RandomForestRegressor(n_trees, min_samples_leaf) =
    RandomForestRegressor(n_trees, min_samples_leaf, nothing, nothing)
"""
    LearnAPI.fit(m::RandomForestRegressor, X, y)

Train on `X`, `y` and return a new `RandomForestRegressor` with the same
hyperparameters as `m` and its `trees` and `oob_score` fields filled in.
`m` itself is immutable and is left untouched.
"""
function LearnAPI.fit(m::RandomForestRegressor, X, y)
    # Sketch only: the training step must define `trees` and `oob_score` here.
    # trees = ...
    # oob_score = ...
    fitted = RandomForestRegressor(m.n_trees, m.min_samples_leaf, trees, oob_score)
    return fitted
end
```

Here’s some more code that demonstrates what the clustering models could look like:

## Clustering models

```
module LearnAPI

"""
    fit(m, X[, y])

Fit model `m` to data and return a new instance of the same type whose fitted fields
are populated. Unfitted models carry `nothing` in those fields.
"""
function fit end

"""
    predict(m, Xnew)

Return predictions of fitted model `m` on new data `Xnew`. Only models that can
label new points provide a method.
"""
function predict end

# Previously misspelled `cluster_lables`, which made every
# `LearnAPI.cluster_labels(...)` method definition fail.
"""
    cluster_labels(m)

Return the cluster labels stored on the fitted clustering model `m`.
"""
function cluster_labels end

end
"""
    DBSCAN(eps, min_samples)

DBSCAN clustering model.

`eps` and `min_samples` are hyperparameters. `centers` and `labels` hold fitted state
and are `nothing` for an untrained model. The two-argument constructor builds such an
untrained model.
"""
struct DBSCAN
    eps::Float64
    min_samples::Int
    centers::Union{Vector{Vector{Float64}}, Nothing}  # fitted state; `nothing` when untrained
    labels::Union{Vector{Int}, Nothing}               # fitted state; `nothing` when untrained
end

# Untrained model: only hyperparameters are given, every fitted field starts as `nothing`.
DBSCAN(eps, min_samples) = DBSCAN(eps, min_samples, nothing, nothing)
"""
    LearnAPI.fit(m::DBSCAN, X)

Cluster `X` and return a new `DBSCAN` with the same hyperparameters as `m` and its
`centers` and `labels` fields filled in. `m` itself is immutable and is left untouched.
"""
function LearnAPI.fit(m::DBSCAN, X)
    # Sketch only: the clustering step must define `centers` and `labels` here.
    # centers = ...
    # labels = ...
    fitted = DBSCAN(m.eps, m.min_samples, centers, labels)
    return fitted
end
# No predict method for DBSCAN.
"""
    LearnAPI.cluster_labels(m::DBSCAN)

Return the cluster labels stored on a fitted `DBSCAN` model.

Throws an `ArgumentError` if `m` is untrained (its `labels` field is `nothing`).
"""
function LearnAPI.cluster_labels(m::DBSCAN)
    # Bind the field to a local: inference narrows a local after the `=== nothing`
    # check, but not a repeated field access, so this returns a concrete Vector{Int}.
    labels = m.labels
    if labels === nothing
        throw(ArgumentError("DBSCAN model not fit yet."))
    end
    return labels
end
# More specific extraction functions such as `cluster_centers` are probably not
# defined in LearnAPI, so this one lives outside that module.
"""
    cluster_centers(m::DBSCAN)

Return the `centers` stored on a fitted `DBSCAN` model.

Throws an `ArgumentError` if `m` is untrained (its `centers` field is `nothing`).
"""
function cluster_centers(m::DBSCAN)
    # Local binding lets the `=== nothing` check narrow the return type.
    centers = m.centers
    if centers === nothing
        throw(ArgumentError("DBSCAN model not fit yet."))
    end
    return centers
end
"""
    KMeans(n_clusters)

K-means clustering model.

`n_clusters` is the hyperparameter. `centers` and `labels` hold fitted state and are
`nothing` for an untrained model. The one-argument constructor builds such an untrained
model.
"""
struct KMeans
    n_clusters::Int
    centers::Union{Vector{Vector{Float64}}, Nothing}  # fitted state; `nothing` when untrained
    labels::Union{Vector{Int}, Nothing}               # fitted state; `nothing` when untrained
end

# Untrained model: only the hyperparameter is given, every fitted field starts as `nothing`.
KMeans(n_clusters) = KMeans(n_clusters, nothing, nothing)
"""
    LearnAPI.fit(m::KMeans, X)

Cluster `X` and return a new `KMeans` with the same `n_clusters` as `m` and its
`centers` and `labels` fields filled in. `m` itself is immutable and is left untouched.
"""
function LearnAPI.fit(m::KMeans, X)
    # Sketch only: the clustering step must define `centers` and `labels` here.
    # centers = ...
    # labels = ...
    fitted = KMeans(m.n_clusters, centers, labels)
    return fitted
end
"""
    LearnAPI.predict(m::KMeans, Xnew)

Intended to return cluster labels for the new data `Xnew`.
Not implemented yet: this stub currently returns `nothing`.
"""
function LearnAPI.predict(m::KMeans, Xnew)
    # TODO: compute and return cluster labels for `Xnew`.
    return nothing
end
"""
    LearnAPI.cluster_labels(m::KMeans)

Return the cluster labels stored on a fitted `KMeans` model.

Throws an `ArgumentError` if `m` is untrained (its `labels` field is `nothing`).
"""
function LearnAPI.cluster_labels(m::KMeans)
    # Bind the field to a local: inference narrows a local after the `=== nothing`
    # check, but not a repeated field access, so this returns a concrete Vector{Int}.
    labels = m.labels
    if labels === nothing
        throw(ArgumentError("KMeans model not fit yet."))
    end
    return labels
end
# More specific extraction functions such as `cluster_centers` are probably not
# defined in LearnAPI, so this one lives outside that module.
"""
    cluster_centers(m::KMeans)

Return the `centers` stored on a fitted `KMeans` model.

Throws an `ArgumentError` if `m` is untrained (its `centers` field is `nothing`).
"""
function cluster_centers(m::KMeans)
    # Local binding lets the `=== nothing` check narrow the return type.
    centers = m.centers
    if centers === nothing
        throw(ArgumentError("KMeans model not fit yet."))
    end
    return centers
end
```