How to implement MLJ models properly for `learning_curve!`

I am trying to reproduce the XGBoost.jl Crabs tutorial using JLBoost.jl instead. I think the data handling is identical and most things work fine, but I can’t get learning_curve! to work. Here is the code.

using Pkg; Pkg.activate(".")
using MLJ, StatsBase, Random, PyPlot, CategoricalArrays, PrettyPrinting, DataFrames
X, y = @load_crabs
X = DataFrame(X)

using XGBoost, MLJ
@load XGBoostClassifier
xgb  = XGBoostClassifier()
xgbm = machine(xgb, X, y)
r = range(xgb, :num_round, lower=10, upper=500)
curve = learning_curve!(xgbm, resampling=CV(),
                        range=r, resolution=25,
                        measure=cross_entropy)


]add https://github.com/xiaodaigh/JLBoost.jl#development

using JLBoost
xgb = JLBoostClassifier()
xgbm = machine(xgb, X, y)
r = range(xgb, :nrounds, lower=10, upper=500)
curve = learning_curve!(xgbm, resampling=CV(),
                        range=r, resolution=25,
                        measure=cross_entropy)

But I am getting the error below, and I am not sure where this CrossEntropy comes from. I also tried to look through the XGBoost MLJ implementation, but I can’t find what I should change to get this to work.


MethodError: no method matching (::MLJBase.CrossEntropy)(::Array{Float64,1}, ::CategoricalArray{String,1,UInt8,String,CategoricalString{UInt8},Union{}})
Closest candidates are:
  Any(!Matched::AbstractArray{#s160,1} where #s160<:UnivariateFinite, ::AbstractArray{#s159,1} where #s159<:(Union{CategoricalString{U}, CategoricalValue{#s15,U} where #s15} where U)) at C:\Users\RTX2080\.julia\packages\MLJBase\JdmO3\src\measures\finite.jl:36

The cross_entropy metric is for use with probabilistic predictors. XGBoostClassifier is a probabilistic model (it subtypes Probabilistic), which means predict should return a vector of distribution objects (UnivariateFinite objects in this case). It looks like your predictions are point values, i.e. your model is making deterministic predictions. (I would expect a more informative error to be thrown in your case, but perhaps that is because JLBoostClassifier is not appropriately subtyped, or is missing a target_scitype declaration.)

So you should use a measure that accepts deterministic predictions, for example misclassification_rate, accuracy, or FScore(). You can list all such measures like this:

julia> measures() do m
       m.prediction_type == :deterministic &&
       m.target_scitype == AbstractVector{<:Finite}
       end
15-element Array{NamedTuple{(:name, :target_scitype, :supports_weights, :prediction_type, :orientation, :reports_each_observation, :aggregation, :is_feature_dependent, :docstring, :distribution_type),T} where T<:Tuple,1}:
 (name = accuracy, ...)              
 (name = balanced_accuracy, ...)     
 (name = FScore(β), ...)             
 (name = fdr, ...)                   
 (name = fn, ...)                    
 (name = fnr, ...)                   
 (name = fp, ...)                    
 (name = fpr, ...)                   
 (name = misclassification_rate, ...)
 (name = npv, ...)                   
 (name = ppv, ...)                   
 (name = tn, ...)                    
 (name = tnr, ...)                   
 (name = tp, ...)                    
 (name = tpr, ...)                   

julia> info(accuracy)
accuracy; aliases: `accuracy`
(name = "accuracy",
 target_scitype = AbstractArray{#s166,1} where #s166<:Finite,
 supports_weights = true,
 prediction_type = :deterministic,
 orientation = :score,
 reports_each_observation = false,
 aggregation = MLJBase.Mean(),
 is_feature_dependent = false,
 docstring = "accuracy; aliases: `accuracy`",
 distribution_type = missing,)

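Concretely, keeping your deterministic predictions for now, the call from your post should go through once the measure is swapped. I haven't run this against JLBoost myself, but with the machine and range you already defined, something like this is what I have in mind:

curve = learning_curve!(xgbm, resampling=CV(),
                        range=r, resolution=25,
                        measure=misclassification_rate)  # deterministic measure instead of cross_entropy
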
But, like XGBoost, you may instead want to make your classifier probabilistic. See this section of the manual for details and the XGBoost implementation for a template.
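
For orientation, here is a minimal, self-contained sketch of what a probabilistic classifier looks like to MLJ. The model (ConstantProbClassifier) and its fit/predict logic are made up for illustration — it just predicts the training class frequencies and has nothing to do with JLBoost. The parts that matter are the Probabilistic supertype, predict returning a vector of UnivariateFinite distributions, and the target_scitype declaration:

import MLJBase

# Toy model, illustrative only: always predicts the class frequencies
# observed during training.
mutable struct ConstantProbClassifier <: MLJBase.Probabilistic end

function MLJBase.fit(::ConstantProbClassifier, verbosity, X, y)
    classes_seen = MLJBase.classes(y[1])              # all levels in the pool
    probs = [count(==(c), y)/length(y) for c in classes_seen]
    fitresult = (classes_seen, probs)
    return fitresult, nothing, nothing                # fitresult, cache, report
end

function MLJBase.predict(::ConstantProbClassifier, fitresult, Xnew)
    classes_seen, probs = fitresult
    d = MLJBase.UnivariateFinite(classes_seen, probs)
    return fill(d, MLJBase.nrows(Xnew))               # one distribution per input row
end

# tells MLJ (and the measures) what kind of target the model predicts
MLJBase.target_scitype(::Type{<:ConstantProbClassifier}) =
    AbstractVector{<:MLJBase.Finite}

With predict returning UnivariateFinite distributions like this, cross_entropy and the other probabilistic measures become available.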

Note, however, that learning_curve! is not efficient for XGBoost, because the XGBoost implementation of the MLJ interface currently lacks an update method. This means training is restarted from scratch each time the number of iterations is increased. For an example (of a tree boosting algorithm!) where the update method is implemented, see EvoTrees.jl. A sketch of what such a method looks like is given below.
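
For reference, the update mechanism boils down to adding a method like the following. This is a hypothetical sketch only — MyBooster, its nrounds field, the cache layout, and train_more_rounds are stand-ins, not JLBoost or XGBoost code:

# MLJ calls `update` instead of `fit` when the data is unchanged and only
# hyperparameters were mutated. For an iterative model we can then train the
# extra boosting rounds instead of restarting from scratch.
function MLJBase.update(model::MyBooster, verbosity,
                        old_fitresult, old_cache, X, y)
    rounds_done = old_cache.nrounds                   # stored by `fit`
    delta = model.nrounds - rounds_done
    delta < 0 && return MLJBase.fit(model, verbosity, X, y)  # cold restart

    fitresult = train_more_rounds(old_fitresult, X, y, delta)  # hypothetical helper
    cache = (nrounds = model.nrounds,)
    report = nothing
    return fitresult, cache, report
end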


Thanks for the tip! For the record, this is now working in the latest master of MLJJLBoost.jl.