Unable to find `fit` in `CatBoost.MLJCatBoostInterface`

Hi!

I have developed an ML workflow using MLJ with data preparation, feature engineering, cross-validation, and tuning (via the `evaluate!` function). Until now I used LightGBM, which worked fine, but since I have categorical variables, I wanted to test the performance of CatBoost.

I run into an issue when using it. Although I have added CatBoost and imported it, I keep getting this message when the evaluation of the models starts.

Evaluating over 81 metamodels:   1%[>                        ]  ETA: 0:13:12
┌ Error: Problem fitting the machine machine(DeterministicTunedModel(model = CatBoostRegressor(iterations = 1000, …), …), …).
└ @ MLJBase ~/.julia/packages/MLJBase/7nGJF/src/machines.jl:694
[ Info: Running type checks... 
[ Info: Type checks okay. 
ERROR: UndefVarError: `fit` not defined in `CatBoost.MLJCatBoostInterface`
Suggestion: check for spelling errors or missing imports.
Hint: a global variable of this name also exists in StatsBase.
Hint: a global variable of this name may be made accessible by importing Distributions in the current active module Main
Hint: a global variable of this name may be made accessible by importing GLM in the current active module Main
Hint: a global variable of this name may be made accessible by importing MLJModelInterface in the current active module Main
Hint: a global variable of this name may be made accessible by importing MLJBase in the current active module Main
Stacktrace:
  [1] update(mlj_model::CatBoost.MLJCatBoostInterface.CatBoostRegressor, verbosity::Int64, fitresult::PythonCall.Core.Py, cache::@NamedTuple{…}, data_pool::PythonCall.Core.Py)
    @ CatBoost.MLJCatBoostInterface ~/.julia/packages/CatBoost/8tf8r/src/MLJCatBoostInterface.jl:153
  [2] fit_only!(mach::MLJBase.Machine{…}; rows::Vector{…}, verbosity::Int64, force::Bool, composite::Nothing)
    @ MLJBase ~/.julia/packages/MLJBase/7nGJF/src/machines.jl:716
  [3] fit_only!
    @ ~/.julia/packages/MLJBase/7nGJF/src/machines.jl:617 [inlined]
...

Any recommendations on how to solve it?
Thx
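
For readers trying to interpret the message: the stack trace shows the `UndefVarError` being raised inside `update`, which MLJ calls when an already-fitted machine is retrained with changed hyperparameters, as happens during tuning. One plausible cause (an assumption, not confirmed in this thread) is an unqualified call to `fit` from a module that never imported that name. The sketch below reproduces that pattern with toy modules; `ToyInterface`, `ToyModel`, `Regressor`, and `update_fixed` are hypothetical names and are not the real MLJModelInterface or CatBoost.jl code.

```julia
# Hypothetical stand-in for a generic model interface (not MLJModelInterface).
module ToyInterface
function fit end
function update end
end

# Hypothetical stand-in for a package implementing that interface.
module ToyModel
import ..ToyInterface as TI

struct Regressor end

# Methods are added to the interface functions through the qualified name,
# so the bare name `fit` is never bound inside ToyModel.
TI.fit(::Regressor, verbosity, data) = (sum(data), nothing, nothing)

# Bug pattern: `fit` is called unqualified and is not defined in this module.
TI.update(m::Regressor, verbosity, fitresult, cache, data) = fit(m, verbosity, data)

# Possible fix: qualify the call with the interface module.
update_fixed(m::Regressor, verbosity, fitresult, cache, data) = TI.fit(m, verbosity, data)
end

model = ToyModel.Regressor()
data = [1.0, 2.0, 3.0]

# Calling the buggy method throws; capture the exception instead of crashing.
err = try
    ToyInterface.update(model, 0, nothing, nothing, data)
    nothing
catch e
    e
end
println(err isa UndefVarError)

# The qualified call resolves correctly.
fitresult, _, _ = ToyModel.update_fixed(model, 0, nothing, nothing, data)
println(fitresult)
```

A plain first fit never reaches `update`, which would explain why the problem only shows up once tuning starts refitting the model.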

Thanks for reporting this! To better understand the issue, could you provide a Minimal Working Example (MWE) that reproduces the error?

Below is an MWE (extracted from a heavy ML application). The package versions are listed at the end of the script.
Thx for your help!

using MLJ
using RDatasets

# Instantiate the model
CatBoostRegressor = @load CatBoostRegressor pkg=CatBoost
ml_model = CatBoostRegressor()
# Pipeline definition
pipe = ml_model

# Hyperparameter ranges for CatBoostRegressor
hp_ranges = [
    range(pipe, :(iterations), lower = 30, upper = 300, scale=:log),
    range(pipe, :(learning_rate), lower=0.01, upper=0.2, scale=:log),
    range(pipe, :(depth), lower = 3, upper = 10, scale=:log),
    range(pipe, :(subsample), lower = 0.6, upper = 1., scale=:linear),
    range(pipe, :(l2_leaf_reg), lower = 3, upper = 100, scale=:log)
]

# My hyperparams grid
tuned_model = TunedModel(
    model=pipe,
    tuning=Grid(resolution=4),
    ranges=hp_ranges,
    measure=mae,
    train_best=true,
    cache=false,
    compact_history=true,
)

# Prepare the data
y, x = unpack(
    dataset("datasets", "iris"),
    ==(:SepalWidth)
)

# Define machine
mach = machine(
    tuned_model,
    x,
    y,
    cache=false,
    scitype_check_level=2
)

# Evaluate 
MLJ.evaluate!(
    mach,
    measure=mae,
    acceleration=CPU1(),
    verbosity=2,
    per_observation=false
)

# With package versions: 
#   [cbdf2221] AlgebraOfGraphics v0.8.13
#   [c6697862] AzStorage v2.7.0
#   [6e4b80f9] BenchmarkTools v1.5.0
#   [336ed68f] CSV v0.10.15
# ⌃ [13f3f980] CairoMakie v0.12.16
#   [e2e10f9a] CatBoost v0.3.5
#   [324d7699] CategoricalArrays v0.10.8
#   [caabdcdb] Cleaner v1.1.1
#   [944b1d66] CodecZlib v0.7.6
#   [a93c6f00] DataFrames v1.7.0
#   [864edb3b] DataStructures v0.18.20
#   [89f0c457] FilesystemDatastructures v1.1.0
# ⌃ [cd3eb016] HTTP v1.10.12
#   [f7bf1975] Impute v0.6.12
#   [5903a43b] Infiltrator v1.8.3
#   [c8e1da08] IterTools v1.10.0
#   [9da8a3cd] JLSO v2.7.0
#   [682c06a0] JSON v0.21.4
#   [0f8b85d8] JSON3 v1.14.1
#   [7acf609c] LightGBM v1.0.0
#   [e6f89c97] LoggingExtras v1.1.0
#   [add582a8] MLJ v0.20.7
#   [a7f614a8] MLJBase v1.7.0
#   [e80e1ace] MLJModelInterface v1.11.0
#   [ce6b1742] RDatasets v0.7.7
#   [295af30f] Revise v3.6.4
#   [b0e4dd01] RollingFunctions v0.8.1
#   [8523bd24] ShapML v0.3.2
#   [1277b4bf] ShiftedArrays v2.0.0
# ⌃ [2913bbd2] StatsBase v0.34.3
#   [69024149] StringEncodings v0.3.7
#   [bd369af6] Tables v1.12.0
#   [f269a46b] TimeZones v1.19.0
#   [5c2747f8] URIs v1.5.1
#   [ddb6d928] YAML v0.4.12
#   [a5390f91] ZipFile v0.10.1

Thank you for reporting this. Can you verify that everything is working on your end? Here is the PR with the fix:

Hey,

It works perfectly now. Thanks for the fix.
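
For anyone landing here with the same error, a minimal sketch of picking up the fix, assuming it has since been released in a CatBoost.jl version newer than the v0.3.5 listed above (the thread does not say which release contains it):

```julia
using Pkg

# Update CatBoost.jl in the active environment to the newest compatible release.
Pkg.update("CatBoost")

# Print the installed version to confirm it is newer than v0.3.5.
Pkg.status("CatBoost")
```

If the version does not change, another package in the environment may be holding it back, and the fix may need to be installed from the PR branch instead.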