Example of using the DecisionTree.permutation_importance function with MLJ

Can someone give me a small example of how to use the function DecisionTree.permutation_importance()
using an MLJ Tree machine (Tree = @load DecisionTreeClassifier pkg=DecisionTree)?
A similar Python function is sklearn.inspection.permutation_importance (see the scikit-learn 1.2.0 documentation).

cc @samuel_okon

@lgmendes unfortunately, there isn’t a standard way to do this with MLJ machines yet.
This wasn’t directly exposed from the DecisionTree package, because it’s a method that could be applied to any other predictive model, not just DecisionTree classifiers or DecisionTree regressors. Hence, there are plans to add this functionality for all supervised machines.
For now, you can use the following code to evaluate permutation importance. You may wish to replace f1score with any other appropriate score from MLJ's measures.

julia> using MLJ, DecisionTree

julia> X, y = @load_crabs
((FL = [8.1, 8.8, 9.2, 9.6, 9.8, 10.8, 11.1, 11.6, 11.8, 11.8  …  20.3, 20.5, 20.6, 20.9, 21.3, 21.4, 21.7, 21.9, 22.5, 23.1], RW = [6.7, 7.7, 7.8, 7.9, 8.0, 9.0, 9.9, 9.1, 9.6, 10.5  …  16.0, 17.5, 17.5, 16.5, 18.4, 18.0, 17.1, 17.2, 17.2, 20.2], CL = [16.1, 18.1, 19.0, 20.1, 20.3, 23.0, 23.8, 24.5, 24.2, 25.2  …  39.4, 40.0, 41.5, 39.9, 43.8, 41.2, 41.7, 42.6, 43.0, 46.2], CW = [19.0, 20.8, 22.4, 23.1, 23.0, 26.5, 27.1, 28.4, 27.8, 29.3  …  44.1, 45.5, 46.2, 44.7, 48.4, 46.2, 47.2, 47.4, 48.7, 52.5], BD = [7.0, 7.4, 7.7, 8.2, 8.2, 9.8, 9.8, 10.4, 9.7, 10.3  …  18.0, 19.2, 19.2, 17.5, 20.0, 18.7, 19.6, 19.5, 19.8, 21.1]), CategoricalArrays.CategoricalValue{String, UInt32}["B", "B", "B", "B", "B", "B", "B", "B", "B", "B"  …  "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"])

julia> Tree = @load DecisionTreeClassifier pkg=DecisionTree
[ Info: For silent loading, specify `verbosity=0`.
import MLJDecisionTreeInterface ✔
MLJDecisionTreeInterface.DecisionTreeClassifier

julia> mach = machine(Tree(), X, y) |> MLJ.fit!
[ Info: Training machine(DecisionTreeClassifier(max_depth = -1, …), …).
trained Machine; caches model-specific representations of data
  model: DecisionTreeClassifier(max_depth = -1, …)
  args:
    1:  Source @900 ⏎ Table{AbstractVector{Continuous}}
    2:  Source @581 ⏎ AbstractVector{Multiclass{2}}


julia> fitted_tree = fitted_params(mach).tree
Decision Tree
Leaves: 17
Depth:  9

julia> class_list = MLJ.int(classes(y))
2-element Vector{UInt32}:
 0x00000001
 0x00000002

julia> function wrapped_f1score(tree, ylabels_val, X_val_matrix)
       y_pred = categorical(DecisionTree.apply_tree(tree, X_val_matrix), levels=class_list, ordered=true)
       return MLJ.f1score(y_pred, categorical(ylabels_val, levels=class_list, ordered=true))
       end
wrapped_f1score (generic function with 1 method)

julia> DecisionTree.permutation_importance(fitted_tree, MLJ.int(y), MLJ.matrix(X), wrapped_f1score)
(mean = [0.4663372650034014, 0.0, 0.02936705464649414, 0.32679053711036654, 0.2246004559048037],
 std = [0.024054791382060164, 0.0, 0.008290294746685429, 0.017755155428814066, 0.018309476843891812],
 scores = [0.4607843137254902 0.4926829268292683 0.4455445544554455; 0.0 0.0 0.0; … ; 0.31999999999999995 0.31343283582089554 0.34693877551020413; 0.20346320346320346 0.23555555555555552 0.23478260869565215],)

For more info see ?DecisionTree.permutation_importance
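
The same wrapper pattern works with any MLJ measure callable as `measure(ŷ, y)`. As a minimal sketch (assuming the `n_iter` keyword of `DecisionTree.permutation_importance`, which controls the number of shuffles per feature), here is the example above with `accuracy` swapped in for `f1score`:

```julia
using MLJ, DecisionTree

X, y = @load_crabs
Tree = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
mach = machine(Tree(), X, y) |> MLJ.fit!

fitted_tree = fitted_params(mach).tree
class_list = MLJ.int(classes(y))

# DecisionTree.permutation_importance expects a score of the form
# score(model, y_labels, X_matrix); wrap the MLJ measure accordingly.
wrapped_accuracy(tree, ylabels, Xmat) =
    MLJ.accuracy(categorical(DecisionTree.apply_tree(tree, Xmat), levels=class_list),
                 categorical(ylabels, levels=class_list))

result = DecisionTree.permutation_importance(fitted_tree, MLJ.int(y), MLJ.matrix(X),
                                             wrapped_accuracy; n_iter=5)
result.mean  # one mean importance value per feature column of X
```

Note that importances are score drops relative to the unpermuted baseline, so an irrelevant feature can come out near zero (or even slightly negative).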
