Version 1.2 of the SIRUS.jl package has just been registered. Since this version, the package can be used for both classification and regression.
The problem with modern-day machine learning is that we appear to be moving further and further away from model explainability and interpretability. With large language models (LLMs), for example, it is impossible to explain exactly why certain decisions are made. This is a problem in high-stakes contexts where the decisions made by the model have real-world effects on people. Unfortunately, many people seem to agree that semi-interpretable models, such as random forests combined with SHAP, are “good enough”. However, we don’t have to settle for semi-interpretable models, because we can use fully interpretable rule-based models.
SIRUS.jl is a pure Julia implementation of the SIRUS algorithm by Bénard et al. (2021). The algorithm produces a rule-based machine learning model, which makes it fully interpretable. It does this by first fitting a random forest and then converting this forest into a small set of rules. Furthermore, the algorithm is stable and achieves a predictive performance that is comparable to LightGBM, a state-of-the-art gradient boosting model created by Microsoft. Interpretability, stability, and predictive performance are described in more detail below.
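The second step, turning the forest into a short list of rules, can be pictured as frequency counting. As we understand Bénard et al. (2021), each tree contributes the rules found along its paths, and the rules that occur in the most trees are kept. The Julia sketch below illustrates that idea under simplifying assumptions. It only uses single-condition rules, and the trees are made up. `Rule` and `select_rules` are hypothetical names, not the SIRUS.jl API. `max_rules` mirrors the hyperparameter of the same name used later in this post.

```julia
# Hypothetical single-condition rule (feature, threshold), read as "feature < threshold".
const Rule = Tuple{Symbol,Float64}

# Simplified sketch: count in how many trees each rule occurs and keep the
# `max_rules` most frequent ones. This is not the SIRUS.jl implementation.
function select_rules(rules_per_tree::Vector{Vector{Rule}}, max_rules::Int)
    counts = Dict{Rule,Int}()
    for tree_rules in rules_per_tree
        for rule in unique(tree_rules)  # count a rule at most once per tree
            counts[rule] = get(counts, rule, 0) + 1
        end
    end
    ranked = sort(collect(counts); by = last, rev = true)  # most frequent first
    return first.(ranked[1:min(max_rules, length(ranked))])
end

# Rules extracted from three made-up trees.
forest_rules = [
    Rule[(:nodes, 14.0), (:age, 38.0)],
    Rule[(:nodes, 14.0), (:year, 1960.0)],
    Rule[(:nodes, 14.0), (:age, 38.0)],
]
println(select_rules(forest_rules, 2))  # keeps (:nodes, 14.0) and (:age, 38.0)
```

Counting only works if different trees produce exactly the same thresholds. That is why the stabilization step discussed later matters.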
To see that the algorithm is interpretable, let’s take the Haberman dataset, which records whether patients survived after surgery for breast cancer. If we fit the SIRUS algorithm on the full dataset, we get the following fitted model:
StableRules model with 8 rules:
 if X[i, :nodes] < 14.0 then 0.152 else 0.094 +
 if X[i, :nodes] < 8.0 then 0.085 else 0.039 +
 if X[i, :nodes] < 4.0 then 0.077 else 0.044 +
 if X[i, :nodes] < 2.0 then 0.071 else 0.047 +
 if X[i, :nodes] < 1.0 then 0.072 else 0.057 +
 if X[i, :year] < 1960.0 then 0.018 else 0.023 +
 if X[i, :age] < 38.0 then 0.029 else 0.023 +
 if X[i, :age] < 42.0 then 0.052 else 0.043
and 2 classes: [0.0, 1.0].
Note: showing only the probability for class 1.0 since class 0.0 has probability 1 - p.
This shows that the model contains 8 rules. The first rule, for example, can be explained as:
If the number of detected axillary nodes is lower than 14, then take 0.152; otherwise, take 0.094.
This is done for all 8 rules, and the selected values are summed to get the prediction. In essence, the first rule says that if fewer than 14 axillary nodes are detected, then the patient is more likely to survive (class == 1.0). This makes sense because it basically says that if many axillary nodes are detected, then it is (unfortunately) less likely that the patient will survive.
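The summation can be made concrete with a minimal Julia sketch. It hard-codes the eight rules printed above and evaluates them for one hypothetical patient. `RULES` and `predict_class1` are illustrative helpers we introduce here; they are not part of the SIRUS.jl API.

```julia
# Each rule: (feature, threshold, value if feature < threshold, value otherwise).
# The numbers are copied from the fitted Haberman model printed above.
const RULES = [
    (:nodes, 14.0, 0.152, 0.094),
    (:nodes, 8.0, 0.085, 0.039),
    (:nodes, 4.0, 0.077, 0.044),
    (:nodes, 2.0, 0.071, 0.047),
    (:nodes, 1.0, 0.072, 0.057),
    (:year, 1960.0, 0.018, 0.023),
    (:age, 38.0, 0.029, 0.023),
    (:age, 42.0, 0.052, 0.043),
]

# Hypothetical helper: probability of class 1.0 = sum of the values selected by every rule.
predict_class1(row::NamedTuple) = sum(row[f] < t ? a : b for (f, t, a, b) in RULES)

patient = (age = 50.0, year = 1965.0, nodes = 0.0)  # hypothetical patient
p = predict_class1(patient)
println("P(class 1.0) = ", round(p; digits = 3))      # 0.546
println("P(class 0.0) = ", round(1 - p; digits = 3))  # 0.454
```

Every term of the sum can be traced back to one rule. This is what makes a single prediction fully reproducible by hand.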
This model is fully interpretable because it contains only a few rules, and each of them can be interpreted reasonably well in isolation. Random forests, in contrast, consist of hundreds to thousands of trees, which are not interpretable due to the sheer number of trees. A common workaround is to use SHAP or Shapley values to visualize the fitted model. The problem with those methods is that they do not allow the predictions to be fully reproduced. For example, if we inspected the fitted model on the aforementioned Haberman dataset via SHAP, then we could learn feature importances. In practice, that would mean that we could tell which features were important. In many real-world situations, this is not enough. Imagine having to tell a patient who was misdiagnosed by the model: “Sorry about our prediction, we were wrong and we didn’t really know why. We only know that nodes is an important feature in the model, but we don’t know whether this played a large role in your situation.”
We had to solve this problem for our research too. When trying to select special forces recruits, we wanted to be able to interpret the model fully. As in the previous example, sending a recruit away needs a stronger justification than “The model said no.” Instead, we wanted to be specific: “You took more than 700 seconds to run 2800 meters and your sprint time was above 30.2; together, this caused the model to give you a low probability of success.” For more information about this study and the SIRUS application, see our paper on Predicting Special Forces Dropout.
Another problem that the SIRUS algorithm solves is model stability. A stable model is defined as a model that leads to similar conclusions for small changes to the data (Yu, 2020). Unstable models can be difficult to apply in practice, since they might require processes to change constantly. They are also considered less trustworthy. For example, going back to our special forces example, if we present a model to the military that selects recruits on features U and V in one year and on features W and Z in the next, then that will not inspire much confidence within the organization.
Having said that, most statistical models are quite stable, since a higher stability is often correlated with a higher predictive performance. Put differently, an unstable model by definition leads to different conclusions for small changes to the data and, hence, small changes to the data can cause a sudden drop in predictive performance. One model that suffers from low stability is the decision tree. A decision tree is built greedily from the top down: the root split is chosen first, and every later split depends on it. As a result, a small change in the data can change the root split and, with it, the rest of the tree. The SIRUS algorithm addresses the instability of extracting rules from random forests by “stabilizing the trees” (Bénard et al., 2021), and the authors have proven mathematically that the stabilization works. For those who are interested in the stabilization approach, see the original paper or the Tree Stabilization section in the Binary Classification Example.
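The following is one way to picture the stabilization, based on our reading of Bénard et al. (2021). Every tree may only split on a small, fixed set of cutpoints per feature: the empirical quantiles, computed once on the full training data. Because all trees share the same candidate thresholds, small changes to the data rarely move a threshold, and identical rules recur across trees. The Julia sketch below is a simplified illustration, not the SIRUS.jl implementation. `cutpoints`, `gini`, and `best_split` are hypothetical helpers, the toy data is made up, and q = 10 quantile bins is an assumed setting.

```julia
using Statistics: mean, quantile

# Hypothetical helper: candidate thresholds are the interior empirical q-quantiles
# of one feature, computed once on the full training data and shared by all trees.
function cutpoints(x::AbstractVector{<:Real}, q::Int)
    levels = range(0, 1; length = q + 1)[2:end-1]
    return unique(quantile(x, levels))
end

# Gini impurity of 0/1 labels.
gini(y) = isempty(y) ? 0.0 : (p = mean(y); 2p * (1 - p))

# Best split "x < c" searched among the shared candidates only (weighted Gini impurity).
function best_split(x, y, candidates)
    best_c, best_loss = first(candidates), Inf
    for c in candidates
        left, right = y[x .< c], y[x .>= c]
        loss = (length(left) * gini(left) + length(right) * gini(right)) / length(y)
        if loss < best_loss
            best_c, best_loss = c, loss
        end
    end
    return best_c
end

x = collect(1.0:20.0)    # toy feature
y = Int.(x .>= 11)       # toy binary labels
cps = cutpoints(x, 10)   # q = 10 (assumed)
println(best_split(x, y, cps))            # 10.5

x_perturbed = copy(x)
x_perturbed[10] = 10.3   # small change to a single observation
println(best_split(x_perturbed, y, cps))  # still 10.5: same shared cutpoint
```

By contrast, a CART-style split that places the threshold halfway between adjacent observed values would move from 10.5 to 10.65 after this perturbation. The rule would then no longer match the one found in other trees.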
As stated above, the algorithm converts a large number of trees to a small number of rules to improve interpretability. This comes at a small performance cost. For example, these are the cross-validated scores on the Haberman dataset for various models (source: Binary Classification Example):
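| Model | AUC ± 1.96*SE | Interpretability | Stability |
|---|---|---|---|
| LGBMClassifier (;) | 0.71 ± 0.06 | Medium | High |
| LGBMClassifier (max_depth = 2,) | 0.67 ± 0.05 | Medium | High |
| DecisionTreeClassifier | 0.63 ± 0.06 | High | Low |
| StableForestClassifier (max_depth = 2,) | 0.71 ± 0.05 | Low | High |
| StableRulesClassifier (max_depth = 2, max_rules = 30) | 0.70 ± 0.09 | High | High |
| StableRulesClassifier (max_depth = 2, max_rules = 10) | 0.67 ± 0.07 | High | High |
| StableRulesClassifier | 0.67 ± 0.07 | High | High |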
This shows that the SIRUS algorithm performs comparably to the state-of-the-art LightGBM classifier by Microsoft. The tree depths are set to at most 2 because rules of depth 3 (almost) never end up in the final model.
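The ± values in these tables are 1.96 times a standard error (SE) over the cross-validation folds. A common way to obtain such a margin, and the one assumed here, is to divide the sample standard deviation of the per-fold scores by the square root of the number of folds. The Julia sketch below uses made-up fold scores, and `mean_and_margin` is a hypothetical helper.

```julia
using Statistics: mean, std

# Hypothetical helper: mean score and a 1.96 * SE margin,
# assuming SE = std(scores) / sqrt(nfolds) (sample standard deviation).
function mean_and_margin(scores::AbstractVector{<:Real})
    se = std(scores) / sqrt(length(scores))
    return mean(scores), 1.96 * se
end

fold_scores = [0.70, 0.72, 0.68, 0.74, 0.66]  # made-up per-fold AUCs
m, margin = mean_and_margin(fold_scores)
println("AUC = ", round(m; digits = 2), " ± ", round(margin; digits = 2))
```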
Furthermore, in our paper, the SIRUS algorithm also performed almost as well as LightGBM and outperformed a regularized linear classifier, as shown in Figure 1:
Note that the area under the curve scores are shown in the bottom right of each subplot.
Finally, many benchmarks run as part of the test suite. The results can be found in the summary of the GitHub Actions for the CI workflow. These are the results for version 1.2.0:
| Dataset | Model | Hyperparameters | nfolds | measure | score | 1.96*SE |
|---|---|---|---|---|---|---|
| blobs | LGBMClassifier | (;) | 10 | auc | 0.99 | 0.01 |
| blobs | LGBMClassifier | (max_depth = 2,) | 10 | auc | 0.99 | 0.01 |
| blobs | StableForestClassifier | (max_depth = 2,) | 10 | auc | 1.00 | 0.00 |
| blobs | StableRulesClassifier | (max_depth = 2, max_rules = 30) | 10 | auc | 1.00 | 0.00 |
| blobs | StableRulesClassifier | (max_depth = 2, max_rules = 10) | 10 | auc | 1.00 | 0.00 |
| titanic | LGBMClassifier | (;) | 10 | auc | 0.87 | 0.03 |
| titanic | LGBMClassifier | (max_depth = 2,) | 10 | auc | 0.85 | 0.02 |
| titanic | StableForestClassifier | (max_depth = 2,) | 10 | auc | 0.85 | 0.02 |
| titanic | StableRulesClassifier | (max_depth = 2, max_rules = 30) | 10 | auc | 0.83 | 0.02 |
| titanic | StableRulesClassifier | (max_depth = 2, max_rules = 10) | 10 | auc | 0.82 | 0.02 |
| haberman | LGBMClassifier | (;) | 10 | auc | 0.71 | 0.06 |
| haberman | LGBMClassifier | (max_depth = 2,) | 10 | auc | 0.67 | 0.05 |
| haberman | StableForestClassifier | (max_depth = 2,) | 10 | auc | 0.71 | 0.05 |
| haberman | StableRulesClassifier | (max_depth = 2, max_rules = 30) | 10 | auc | 0.70 | 0.08 |
| haberman | StableRulesClassifier | (max_depth = 2, max_rules = 10) | 10 | auc | 0.67 | 0.07 |
| iris | LGBMClassifier | (;) | 10 | accuracy | 0.94 | 0.04 |
| iris | LGBMClassifier | (max_depth = 2,) | 10 | accuracy | 0.94 | 0.04 |
| iris | StableForestClassifier | (max_depth = 2,) | 10 | accuracy | 0.95 | 0.04 |
| iris | StableRulesClassifier | (max_depth = 2, max_rules = 30) | 10 | accuracy | 0.73 | 0.12 |
| iris | StableRulesClassifier | (max_depth = 2, max_rules = 10) | 10 | accuracy | 0.68 | 0.10 |
| make_regression | LGBMRegressor | (;) | 10 | R² | 0.78 | 0.03 |
| make_regression | LGBMRegressor | (max_depth = 2,) | 10 | R² | 0.61 | 0.04 |
| make_regression | StableForestRegressor | (max_depth = 2,) | 10 | R² | 0.67 | 0.05 |
| make_regression | StableRulesRegressor | (max_depth = 2, max_rules = 30) | 10 | R² | 0.55 | 0.04 |
| make_regression | StableRulesRegressor | (max_depth = 2, max_rules = 10) | 10 | R² | 0.57 | 0.06 |
| boston | LGBMRegressor | (;) | 10 | R² | 0.70 | 0.06 |
| boston | LGBMRegressor | (max_depth = 2,) | 10 | R² | 0.63 | 0.07 |
| boston | StableForestRegressor | (max_depth = 2,) | 10 | R² | 0.67 | 0.08 |
| boston | StableRulesRegressor | (max_depth = 2, max_rules = 30) | 10 | R² | 0.55 | 0.07 |
| boston | StableRulesRegressor | (max_depth = 2, max_rules = 10) | 10 | R² | 0.63 | 0.10 |
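This also shows that the StableRules models score close to LightGBM on blobs, titanic, and haberman. On the two regression datasets, they score somewhat below the depth-limited LightGBM (max_depth = 2). The largest gap is on iris, where the StableRules classifiers reach an accuracy of 0.73 and 0.68 versus 0.94 for LightGBM.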
Arguably, the slightly lower performance is not a bad trade for higher interpretability. More interpretable models can even lead to better predictive performance in the real world, because it is easier to verify the model and make adjustments if needed.
If you are interested in applying the models to your own dataset, you can use them via the MLJ.jl interface. For example, this is the code to fit the model on the full Haberman dataset, where X holds the features and y the target:
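using MLJ
using SIRUS: StableRulesClassifier
model = StableRulesClassifier(; max_depth=1, max_rules=8)
mach = machine(model, X, y)  # X: feature table, y: categorical target (loaded beforehand)
fit!(mach)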
To evaluate model performance via cross-validation, use
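using Random: MersenneTwister
_rng(seed::Int=1) = MersenneTwister(seed)  # assumed helper for reproducible shuffling
nfolds = 10  # assumed; matches the 10 folds used in the benchmarks
resampling = CV(; nfolds, shuffle=true, rng=_rng())
acceleration = MLJ.CPUThreads()
evaluate(model, X, y; acceleration, resampling, measure=auc)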
The Haberman dataset is a binary classification dataset. The approach for regression datasets is the same, except that you use the StableRulesRegressor instead. For more information about the models and hyperparameters, see the API docs.
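The regression workflow can be sketched as follows. The snippet fits a StableRulesRegressor on synthetic data from MLJ's `make_regression`, the generator behind the make_regression benchmark above, and cross-validates it with the R² measure (`rsq`). The sample size, number of features, seeds, and hyperparameter values are arbitrary choices for illustration, not recommended settings.

```julia
using MLJ
using SIRUS: StableRulesRegressor
using Random: MersenneTwister

# Synthetic regression data: 200 rows and 5 features (arbitrary sizes).
X, y = make_regression(200, 5; rng = MersenneTwister(1))

model = StableRulesRegressor(; max_depth = 2, max_rules = 10)
mach = machine(model, X, y)
fit!(mach)
yhat = predict(mach, X)  # one point prediction per row

# 10-fold cross-validated R², mirroring the benchmark setup.
resampling = CV(; nfolds = 10, shuffle = true, rng = MersenneTwister(1))
evaluate(model, X, y; resampling, measure = rsq)
```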