Thanks for commenting on your experience. For the record, you can also get the tree printed by increasing the verbosity level of the fit:
```julia
julia> X, y = @load_iris
((sepal_length = [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9 … 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9], sepal_width = [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1 … 3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0], petal_length = [1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5 … 5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1], petal_width = [0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1 … 2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8]), CategoricalArrays.CategoricalString{UInt8}["setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa" … "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica"])

julia> mach = machine(@load(DecisionTreeClassifier), X, y)

julia> fit!(mach, verbosity=2)
[ Info: Training Machine{DecisionTreeClassifier} @ 1…47.
Feature 3, Threshold 2.45
L-> 1 : 50/50
R-> Feature 4, Threshold 1.75
    L-> Feature 3, Threshold 4.95
        L-> Feature 4, Threshold 1.65
            L-> 2 : 47/47
            R-> 3 : 1/1
        R-> Feature 4, Threshold 1.55
            L-> 3 : 3/3
            R-> Feature 3, Threshold 5.449999999999999
                L-> 2 : 2/2
                R-> 3 : 1/1
    R-> Feature 3, Threshold 4.85
        L-> Feature 2, Threshold 3.1
            L-> 3 : 2/2
            R-> 2 : 1/1
        R-> 3 : 43/43
Machine{DecisionTreeClassifier} @ 1…47

julia> fitted_params(mach)
(tree_or_leaf = Decision Tree
Leaves: 9
Depth: 5,
 encoding = Dict("virginica"=>0x03,"setosa"=>0x01,"versicolor"=>0x02),)
```
The correct way for the user to access the learned parameters is `fp = fitted_params(mach)`. In this case `fp.encoding` tells you how to decode the printed tree levels. If you wanted to print without refitting, you could do `DecisionTree.print_tree(fp.tree_or_leaf)`.
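For instance, to translate the integer class references in the printed tree back into labels, one could invert the encoding. A minimal sketch, assuming the machine fitted above:

```julia
# Minimal sketch, assuming `mach` has been fit as above.
fp = fitted_params(mach)

# Invert the encoding: 0x01 => "setosa", 0x02 => "versicolor", ...
decode = Dict(v => k for (k, v) in fp.encoding)
decode[0x01]  # "setosa", i.e. leaves printed as `1 : 50/50` are setosa

# Reprint the tree without refitting (requires DecisionTree itself):
import DecisionTree
DecisionTree.print_tree(fp.tree_or_leaf)
```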
In general, there are two ways to inspect the outcomes of training: `fitted_params(mach)` returns a named tuple representing (in as user-friendly a form as possible) the actual learned parameters, while `report(mach)` returns everything else. In the case of DecisionTree, one could expand the report to include a method that prints the tree: in the `fit` method in "MLJModels/src/DecisionTree.jl", change `report = (classes_seen=classes_seen,)` to

```julia
report = (classes_seen=classes_seen,
          print_tree=() -> DecisionTree.print_tree(tree, model.display_depth))
```
Or something along those lines. Pull request welcome!
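With a change along those lines, the user could reprint the fitted tree on demand. Hypothetical usage, assuming the report is extended as sketched above:

```julia
# Hypothetical: only works if the report is extended as above.
r = report(mach)
r.print_tree()  # prints the fitted tree, truncated to model.display_depth
```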