How to visualise the structure of the decision tree built by MLJ?

Thanks for commenting on your experience. For the record, you can also get the tree printed by increasing the verbosity level of the fit:

julia> X, y = @load_iris
((sepal_length = [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9  …  6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9], sepal_width = [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1  …  3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0], petal_length = [1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5  …  5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1], petal_width = [0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1  …  2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8]), CategoricalArrays.CategoricalString{UInt8}["setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa"  …  "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica"])

julia> mach = machine(@load(DecisionTreeClassifier), X, y)

julia> fit!(mach, verbosity=2)
[ Info: Training Machine{DecisionTreeClassifier} @ 1…47.
Feature 3, Threshold 2.45
L-> 1 : 50/50
R-> Feature 4, Threshold 1.75
    L-> Feature 3, Threshold 4.95
        L-> Feature 4, Threshold 1.65
            L-> 2 : 47/47
            R-> 3 : 1/1
        R-> Feature 4, Threshold 1.55
            L-> 3 : 3/3
            R-> Feature 3, Threshold 5.449999999999999
                L-> 2 : 2/2
                R-> 3 : 1/1
    R-> Feature 3, Threshold 4.85
        L-> Feature 2, Threshold 3.1
            L-> 3 : 2/2
            R-> 2 : 1/1
        R-> 3 : 43/43
Machine{DecisionTreeClassifier} @ 1…47

julia> fitted_params(mach)
(tree_or_leaf = Decision Tree
Leaves: 9
Depth:  5,
 encoding = Dict("virginica"=>0x03,"setosa"=>0x01,"versicolor"=>0x02),)

The correct way for the user to access the learned parameters is fp = fitted_params(mach). In this case fp.encoding tells you how to decode the integer labels in the printed tree. If you want to print the tree again without refitting, you can call DecisionTree.print_tree(fp.tree_or_leaf).
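Putting that together, something like this should work (a sketch assuming `DecisionTree` itself has been imported, and reusing the `mach` trained above):

```julia
import DecisionTree  # the underlying package, needed for print_tree

fp = fitted_params(mach)

# Map the integer labels appearing in the printed tree back to class names:
fp.encoding  # e.g. Dict("virginica" => 0x03, "setosa" => 0x01, "versicolor" => 0x02)

# Reprint the fitted tree without refitting the machine:
DecisionTree.print_tree(fp.tree_or_leaf)
```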

In general there are two ways to inspect the outcomes of training: fitted_params(mach) returns a named tuple representing (in as user-friendly a form as possible) the actual learned parameters, while report(mach) returns everything else. In the case of DecisionTree, one could expand the report to include a function that prints the tree. (In the fit method in “MLJModels/src/DecisionTree.jl” change report = (classes_seen=classes_seen,) to

report = (classes_seen=classes_seen,
          print_tree=() -> DecisionTree.print_tree(tree, model.display_depth))

Or something along those lines. Pull request welcome!
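With a change along those lines merged, a user could print the tree straight from the report, without touching the underlying DecisionTree package at all. A hypothetical usage sketch:

```julia
# Hypothetical, assuming the suggested print_tree entry has been added
# to the report returned by the DecisionTreeClassifier fit method:
fit!(mach)
rpt = report(mach)
rpt.print_tree()  # prints the fitted tree, truncated at model.display_depth
```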