MLJ TunedModel with imbalanced classes

The TunedModel interface supports a weights parameter but not class_weights.
For classification problems with highly imbalanced classes, we need TunedModel to work with measures that use class weights.
Is it currently possible to use TunedModel with highly imbalanced data for classification?

Most of the measures in MLJ do not directly support class weights. The only measures that currently support class weights are the multiclass-classification measures, as we can see by running the following code snippet:

julia> measures() do m
           m.supports_class_weights
       end
8-element Vector{NamedTuple{(:name, :instances, :human_name, :target_scitype, :supports_weights, :supports_class_weights, :prediction_type, :orientation, :reports_each_observation, :aggregation, :is_feature_dependent, :docstring, :distribution_type), T} where T<:Tuple}:
 (name = MulticlassFScore, instances = [macro_f1score, micro_f1score, multiclass_f1score], ...)
 (name = MulticlassFalseDiscoveryRate, instances = [multiclass_falsediscovery_rate, multiclass_fdr], ...)
 (name = MulticlassFalseNegativeRate, instances = [multiclass_false_negative_rate, multiclass_fnr, multiclass_miss_rate, multiclass_falsenegative_rate], ...)
 (name = MulticlassFalsePositiveRate, instances = [multiclass_false_positive_rate, multiclass_fpr, multiclass_fallout, multiclass_falsepositive_rate], ...)
 (name = MulticlassNegativePredictiveValue, instances = [multiclass_negative_predictive_value, multiclass_negativepredictive_value, multiclass_npv], ...)
 (name = MulticlassPrecision, instances = [multiclass_positive_predictive_value, multiclass_ppv, multiclass_positivepredictive_value, multiclass_precision], ...)
 (name = MulticlassTrueNegativeRate, instances = [multiclass_true_negative_rate, multiclass_tnr, multiclass_specificity, multiclass_selectivity, multiclass_truenegative_rate], ...)
 (name = MulticlassTruePositiveRate, instances = [multiclass_true_positive_rate, multiclass_tpr, multiclass_sensitivity, multiclass_recall, multiclass_hit_rate, multiclass_truepositive_rate], ...)

However, many of the measures support sample weights:

julia> length(measures(m -> m.supports_weights))
26

So what you can do is manually create a vector w of sample weights, where each observation's weight is determined by its true class, and pass w to the weights keyword argument of TunedModel. You will have to pick one of the measures that supports sample weights; measures(m -> m.supports_weights) lists them all.
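
Here is a rough sketch of how such a weight vector could be built in plain Julia. The toy labels and the inverse-frequency weighting scheme are my own assumptions for illustration; any mapping from class to weight would work the same way.

```julia
# Hypothetical toy labels; in practice `y` is your target vector.
y = vcat(fill("neg", 8), fill("pos", 2))

# Count how often each class occurs.
counts = Dict{String,Int}()
for label in y
    counts[label] = get(counts, label, 0) + 1
end

# Inverse-frequency weighting (an assumed, commonly used choice):
# every class ends up contributing the same total weight.
n, k = length(y), length(counts)
class_weight = Dict(c => n / (k * cnt) for (c, cnt) in counts)

# One weight per observation, looked up from the observation's true class.
w = [class_weight[label] for label in y]
```

With these labels each "neg" observation gets weight 0.625 and each "pos" observation gets 2.5, so both classes sum to the same total. The resulting w is what you would then hand to TunedModel through weights=w, along with a measure that supports sample weights.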

By the way, you might have figured this out already, but don’t forget to set resampling=StratifiedCV() if you’re dealing with imbalanced classes. Actually, it’s a good idea to use StratifiedCV for any classification problem.

Another option is to not use weights and instead use a measure that is less influenced by class imbalance, like balanced_accuracy or auc.
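
To see why balanced accuracy is less influenced by imbalance, here is a small hand-rolled comparison. It assumes the usual definition of balanced accuracy as the unweighted mean of per-class recall, and the toy data and helper names (class_recall, acc, bal_acc) are made up for illustration; this is not MLJ's implementation.

```julia
# Hypothetical toy data: 8 negatives, 2 positives, and a
# classifier that always predicts the majority class.
y_true = vcat(fill("neg", 8), fill("pos", 2))
y_pred = fill("neg", 10)

# Plain accuracy: fraction of correct predictions.
acc = count(y_true .== y_pred) / length(y_true)

# Recall for one class: fraction of that class predicted correctly.
class_recall(c) = count((y_true .== c) .& (y_pred .== c)) / count(y_true .== c)

# Balanced accuracy: unweighted mean of the per-class recalls.
classes = unique(y_true)
bal_acc = sum(class_recall(c) for c in classes) / length(classes)

println("accuracy = ", acc, ", balanced accuracy = ", bal_acc)
# prints: accuracy = 0.8, balanced accuracy = 0.5
```

The always-"neg" classifier looks decent on plain accuracy but scores no better than chance on balanced accuracy, which is the behavior you want from a tuning measure on imbalanced data.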