How to visualise the structure of the decision tree built by MLJ?

using RDatasets, DataFrames
iris = dataset("datasets", "iris")

using MLJ # using the MLJ framework
using MLJModels # loads the models MLJ can use, e.g. linear regression, decision tree
tree_model = @load DecisionTreeClassifier verbosity=1
y, X = unpack(iris, ==(:Species), !=(:Species))
tree_machine = machine(tree_model, X, y)
fit!(tree_machine)

using DecisionTree
print_tree(tree_machine)

Normally, with a DecisionTree model I can print it using print_tree, but not so with one fitted via MLJ.

The error is odd because `typeof(tree_machine.model) == DecisionTreeClassifier`, which looks like it matches one of the closest candidates.

┌ Info: A model type "DecisionTreeClassifier" is already loaded.
│ No new code loaded.
└ @ MLJModels C:\Users\RTX2080\.julia\packages\MLJModels\5Qzge\src\loading.jl:43
┌ Info: Training Machine{DecisionTreeClassifier} @ 1…95.
└ @ MLJ C:\Users\RTX2080\.julia\packages\MLJ\BEVGY\src\machines.jl:141
┌ Info: Not retraining Machine{DecisionTreeClassifier} @ 1…95.
│  It appears up-to-date. Use `force=true` to force retraining.
└ @ MLJ C:\Users\RTX2080\.julia\packages\MLJ\BEVGY\src\machines.jl:148
MethodError: no method matching print_tree(::Machine{DecisionTreeClassifier})
Closest candidates are:
  print_tree(!Matched::Nothing) at C:\Users\RTX2080\.julia\packages\DecisionTree\y42n2\src\scikitlearnAPI.jl:390
  print_tree(!Matched::DecisionTreeRegressor) at C:\Users\RTX2080\.julia\packages\DecisionTree\y42n2\src\scikitlearnAPI.jl:389
  print_tree(!Matched::DecisionTree.DecisionTreeClassifier) at C:\Users\RTX2080\.julia\packages\DecisionTree\y42n2\src\scikitlearnAPI.jl:388
  ...

Stacktrace:
 [1] top-level scope at In[74]:14

In the docs I believe it’s explained that the machine is just a wrapper for the model (a container holding just the hyperparameters) together with, essentially, the “fitresult” (whatever is learned during the fitting process that needs to be passed on to the predict or transform method).

So in this case, if you want to recover the actual trained tree, have a look at tree_machine.fitresult; IIRC the first element of that tuple is the tree you want. The same thing is also returned, perhaps more transparently, by fitted_params(tree_machine) as a named tuple.
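To make that concrete, you can inspect the machine directly (a sketch; the exact field layout may vary across MLJ versions):

julia> tree_machine.model              # the model struct: hyperparameters only
julia> typeof(tree_machine.fitresult)  # a tuple; its first element is the raw DecisionTree.jl tree
julia> fitted_params(tree_machine)     # the same information as a friendlier named tuple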


That’s correct. Thanks. Both of these work:

print_tree(tree_machine.fitresult[1])       # the raw tree is the first element of the fitresult tuple

print_tree(fitted_params(tree_machine)[1])  # equivalently, via the named tuple from fitted_params

I must admit I only skimmed the docs; I found this passage:

The fitted_params method
A fitted_params method may be optionally overloaded. Its purpose is to provide MLJ access to a user-friendly representation of the learned parameters of the model (as opposed to the hyperparameters). They must be extractable from fitresult.

After reading it, I wouldn’t have figured out that that’s where I can get the tree model object. So having a post here where people can more easily find the answer is valuable.


Great! We still have a lot to do to clarify how to accomplish the many things users want to do with an ML toolbox, and having people try things out and tell us where they find things unintuitive is very helpful!

Generally the tutorials are a good first place to get started, though of course a fair bit of work is still needed there as well.

Thanks for commenting on your experience. For the record, you can also get the tree printed by increasing the verbosity level of the fit:

julia> X, y = @load_iris
((sepal_length = [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9  …  6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9], sepal_width = [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1  …  3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0], petal_length = [1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5  …  5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1], petal_width = [0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1  …  2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8]), CategoricalArrays.CategoricalString{UInt8}["setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa", "setosa"  …  "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica", "virginica"])

julia> mach = machine(@load(DecisionTreeClassifier), X, y)

julia> fit!(mach, verbosity=2)
[ Info: Training Machine{DecisionTreeClassifier} @ 1…47.
Feature 3, Threshold 2.45
L-> 1 : 50/50
R-> Feature 4, Threshold 1.75
    L-> Feature 3, Threshold 4.95
        L-> Feature 4, Threshold 1.65
            L-> 2 : 47/47
            R-> 3 : 1/1
        R-> Feature 4, Threshold 1.55
            L-> 3 : 3/3
            R-> Feature 3, Threshold 5.449999999999999
                L-> 2 : 2/2
                R-> 3 : 1/1
    R-> Feature 3, Threshold 4.85
        L-> Feature 2, Threshold 3.1
            L-> 3 : 2/2
            R-> 2 : 1/1
        R-> 3 : 43/43
Machine{DecisionTreeClassifier} @ 1…47

julia> fitted_params(mach)
(tree_or_leaf = Decision Tree
Leaves: 9
Depth:  5,
 encoding = Dict("virginica"=>0x03,"setosa"=>0x01,"versicolor"=>0x02),)

The correct way for the user to access the learned parameters is fp = fitted_params(mach). In this case fp.encoding tells you how to decode the class integers in the printed tree. If you want to print the tree without refitting, you can do DecisionTree.print_tree(fp.tree_or_leaf).
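Putting that together for the machine fitted above (a short sketch; the encoding shown is the one from the fitted_params output):

julia> fp = fitted_params(mach);

julia> fp.encoding
Dict("virginica"=>0x03,"setosa"=>0x01,"versicolor"=>0x02)

julia> using DecisionTree

julia> DecisionTree.print_tree(fp.tree_or_leaf)  # reprints the tree without refitting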

In general there are two ways to inspect the outcomes of training: fitted_params(mach) returns a named tuple representing (in as user-friendly a form as possible) the actual learned parameters, while report(mach) returns everything else. In the case of DecisionTree, one could expand the report to include a method that prints the tree. (In the fit method in “MLJModels/src/DecisionTree.jl”, change report = (classes_seen=classes_seen,) to

report = (classes_seen = classes_seen,
          print_tree = () -> DecisionTree.print_tree(tree, model.display_depth))

or something along those lines.) Pull requests welcome!
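With a change along those lines, a user could then print the tree straight from the report; hypothetical usage (this print_tree entry is the proposal above, not an existing API):

julia> report(mach).print_tree()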


I am wondering whether we can plot the learned decision tree model within MLJ, as suggested by the following site: Visualize a Decision Tree in 4 Ways with Scikit-Learn and Python | MLJAR (via the Python version of scikit-learn), instead of simply displaying it as ASCII characters as print_tree does.

The latest EvoTrees.jl version (0.5.2), which is integrated with MLJ, introduces a basic tree visualization:

plot(model, 3) # the second argument selects the i-th tree of the model
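For context, here is a minimal end-to-end sketch using the plain EvoTrees API (this assumes the 0.5-era positional fit_evotree(params, X, y) signature and Plots.jl; toy data only):

using EvoTrees, Plots

X = randn(200, 4)                                # toy feature matrix
y = randn(200)                                   # toy regression target
config = EvoTreeRegressor(max_depth = 4, nrounds = 10)
model = fit_evotree(config, X, y)                # fit a small boosted ensemble
plot(model, 3)                                   # draw the 3rd tree of the model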


Cool! Can the same visualization be applied to trees from DecisionTree.jl?

The plot recipe is unfortunately specific to the EvoTrees structure. It would require some adaptation for other tree-based models.

I suppose decision trees will be around for a while, possibly longer than EvoTrees.jl and DecisionTree.jl. Some AbstractDecisionTree interface, which would define visualizations, would be pretty awesome.
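A hypothetical sketch of what such an interface might look like (every name below is made up for illustration; no package currently defines these):

# A shared abstract node type that tree packages could subtype:
abstract type AbstractDecisionNode end

# Accessors each package would implement for its own node types:
children(node::AbstractDecisionNode)        = error("not implemented")
split_feature(node::AbstractDecisionNode)   = error("not implemented")
split_threshold(node::AbstractDecisionNode) = error("not implemented")
is_leaf(node::AbstractDecisionNode)         = isempty(children(node))

# A single plot recipe or print routine written against these methods
# would then work for DecisionTree.jl, EvoTrees.jl, and future packages.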