[ANN] SIRUS.jl v1.2: Interpretable Machine Learning via Rule Extraction

Version 1.2 of the SIRUS.jl package has just been registered. Since version 1.2, the package can be used for both classification and regression.

The problem with modern-day machine learning is that we appear to be moving further and further away from model explainability and interpretability. With large language models (LLMs), for example, it is impossible to explain exactly why certain decisions are made. This is a problem in high-stakes contexts where the decisions made by the model have real-world effects on people. Unfortunately, many people tend to agree that semi-interpretable models such as random forests with SHAP are “good enough”. However, we don’t have to accept semi-interpretable models because we can use fully interpretable rule-based models.

SIRUS.jl is a pure Julia implementation of the SIRUS algorithm by Bénard et al. (2021). The algorithm produces a rule-based machine learning model, which makes it fully interpretable. It does this by first fitting a random forest and then converting this forest to rules. Furthermore, the algorithm is stable and achieves a predictive performance that is comparable to LightGBM, a state-of-the-art gradient boosting model created by Microsoft. Interpretability, stability, and predictive performance are described in more detail below.

Interpretability

To see that the algorithm is interpretable, let’s take the Haberman dataset, which is a dataset about whether patients survived cancer. If we fit the SIRUS algorithm on the full dataset, we get the following fitted model:

StableRules model with 8 rules:
 if X[i, :nodes] < 14.0 then 0.152 else 0.094 +
 if X[i, :nodes] < 8.0 then 0.085 else 0.039 +
 if X[i, :nodes] < 4.0 then 0.077 else 0.044 +
 if X[i, :nodes] < 2.0 then 0.071 else 0.047 +
 if X[i, :nodes] < 1.0 then 0.072 else 0.057 +
 if X[i, :year] < 1960.0 then 0.018 else 0.023 +
 if X[i, :age] < 38.0 then 0.029 else 0.023 +
 if X[i, :age] < 42.0 then 0.052 else 0.043
and 2 classes: [0.0, 1.0]. 
Note: showing only the probability for class 1.0 since class 0.0 has probability 1 - p.

This shows that the model contains 8 rules. The first rule, for example, can be explained as:

If the number of detected axillary nodes is lower than 14, then take 0.152, otherwise take 0.094.

This is done for all 8 rules and the scores are summed to get a prediction. In essence, the first rule says that if fewer than 14 axillary nodes are detected, then the patient will most likely survive (class == 1.0). This makes sense because it basically says that if many axillary nodes are detected, then it’s (unfortunately) less likely that the patient will survive.
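To make the summation concrete, here is a minimal sketch in plain Julia that re-implements the eight printed rules by hand and sums them for a single patient. The `Rule` struct, the `score` function, and the two example patients are illustrative names and values, not part of the SIRUS.jl API; only the thresholds and scores are copied from the model output above.

```julia
# Hand-written re-implementation of the printed rules; not the SIRUS.jl API.
struct Rule
    feature::Symbol
    threshold::Float64
    then_score::Float64   # used when the feature value is below the threshold
    else_score::Float64   # used otherwise
end

rules = [
    Rule(:nodes, 14.0, 0.152, 0.094),
    Rule(:nodes, 8.0, 0.085, 0.039),
    Rule(:nodes, 4.0, 0.077, 0.044),
    Rule(:nodes, 2.0, 0.071, 0.047),
    Rule(:nodes, 1.0, 0.072, 0.057),
    Rule(:year, 1960.0, 0.018, 0.023),
    Rule(:age, 38.0, 0.029, 0.023),
    Rule(:age, 42.0, 0.052, 0.043),
]

# Sum the score of every rule for one patient given as a NamedTuple.
function score(rules, patient)
    return sum(r -> patient[r.feature] < r.threshold ? r.then_score : r.else_score, rules)
end

# Two made-up patients that differ only in the number of detected nodes.
few_nodes = (nodes=0, year=1965, age=50)
many_nodes = (nodes=20, year=1965, age=50)

println(round(score(rules, few_nodes); digits=3))
println(round(score(rules, many_nodes); digits=3))
```

The patient with no detected nodes ends up with the higher class 1.0 score, and every term in the sum can be traced back to exactly one rule.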

This model is fully interpretable because there are few rules, each of which can be interpreted reasonably well in isolation. Random forests, in contrast, consist of hundreds to thousands of trees, which are not interpretable due to the large number of trees. A common workaround is to use SHAP or Shapley values to visualize the fitted model. The problem with those methods is that they do not allow the predictions to be fully reproduced. For example, if we inspected the fitted model on the aforementioned Haberman dataset via SHAP, we could learn feature importances, that is, we could tell which features were important. In many real-world situations this is not enough. Imagine having to tell a patient who was misdiagnosed by the model: “Sorry about our prediction, we were wrong and we don’t really know why. We only know that nodes is an important feature in the model, but we don’t know whether it played a large role in your situation.”

We had to solve this problem for our research too. When trying to select special forces recruits, we wanted to interpret the model fully. As in the previous example, sending a recruit away needs a stronger defense than “The model said no.” Instead, we wanted to be able to be specific: “You took more than 700 seconds to run 2800 meters and your sprint time was above 30.2; together this caused the model to give you a low probability of success.” For more information about this study and the SIRUS application, see our paper on Predicting Special Forces Dropout.

Stability

Another problem that the SIRUS algorithm solves is that of model stability. A stable model is defined as a model which leads to similar conclusions for small changes to data (Yu, 2020). Unstable models can be difficult to apply in practice since they might require processes to constantly change. Also, they are considered less trustworthy. For example, going back to our special forces example, if we present a model to the military that selects recruits on features U and V in one year and on features W and Z in another, then that does not inspire confidence in the organization.
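One way to put a number on this notion of stability is to fit a model on two slightly different samples and measure how much the selected rules overlap. The sketch below uses the Dice-Sørensen index for that. The rule strings and the `dice_sorensen` helper are made up for illustration, and the choice of metric is an assumption rather than code taken from SIRUS.jl.

```julia
# Dice-Sørensen index between two sets: 1.0 means identical, 0.0 means disjoint.
function dice_sorensen(a::Set, b::Set)
    isempty(a) && isempty(b) && return 1.0
    return 2 * length(intersect(a, b)) / (length(a) + length(b))
end

# Hypothetical rule sets obtained from two fits on slightly different data.
fit_a = Set(["nodes < 14", "nodes < 8", "age < 42", "year < 1960"])
fit_b = Set(["nodes < 14", "nodes < 8", "age < 38", "year < 1960"])

# Three of the four rules are shared between the two fits.
println(dice_sorensen(fit_a, fit_b))
```

A model that selects on features U and V in one year and on W and Z in the next would score 0.0 on this index.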

Having said that, most statistical models are quite stable since higher stability is often correlated with higher predictive performance. Put differently, an unstable model by definition leads to different conclusions for small changes to the data and, hence, small changes to the data can cause a sudden drop in predictive performance. One model which suffers from low stability is the decision tree. This is because a decision tree first creates the root node, so a small change in the data can cause the root, and therefore the rest of the tree, to be completely different. The SIRUS algorithm has solved the instability of random forests by “stabilizing the trees” (Bénard et al., 2021) and the authors have proven mathematically that the stabilization works. For people who are interested in the stabilization approach, see the original paper or Section Tree Stabilization in the Binary Classification Example.
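A rough way to picture the stabilization, assuming it works by restricting the split points of the trees to a fixed set of empirical quantiles of each feature (see the original paper for the actual procedure), is the following sketch. It only uses the `Statistics` standard library; the `cutpoints` helper, the choice of 10 quantiles, and the synthetic data are assumptions for illustration and not the SIRUS.jl implementation.

```julia
using Statistics: quantile

# Candidate split points: q - 1 evenly spaced empirical quantiles of a feature.
cutpoints(x; q=10) = quantile(x, (1:q-1) ./ q)

x = collect(1.0:100.0)   # synthetic feature values
x_changed = copy(x)
x_changed[50] = 50.4     # perturb one observation slightly

a = cutpoints(x)
b = cutpoints(x_changed)

println(a)
println(maximum(abs.(a .- b)))  # largest shift in any candidate split point
```

Because every tree can only split at these few candidate values, a small change in the data moves the candidates very little, and different trees tend to produce the same splits again and again.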

Predictive Performance

As stated above, the algorithm converts a large number of trees to a small number of rules to improve interpretability. This comes at a small performance cost. For example, these are the cross-validated scores on the Haberman dataset for various models (source: Binary Classification Example):

| Model | AUC ± 1.96*SE | Interpretability | Stability |
|---|---|---|---|
| LGBMClassifier() | 0.71 ± 0.06 | Medium | High |
| LGBMClassifier(; max_depth=2) | 0.67 ± 0.05 | Medium | High |
| DecisionTreeClassifier(; max_depth=2) | 0.63 ± 0.06 | High | Low |
| StableForestClassifier(; max_depth=2) | 0.71 ± 0.05 | Low | High |
| StableRulesClassifier(; max_depth=2, max_rules=25) | 0.70 ± 0.09 | High | High |
| StableRulesClassifier(; max_depth=2, max_rules=10) | 0.67 ± 0.07 | High | High |
| StableRulesClassifier(; max_depth=1, max_rules=25) | 0.67 ± 0.07 | High | High |

This shows that the SIRUS algorithm performs comparably to the state-of-the-art LGBM classifier by Microsoft. The tree depths are set to at most 2 because rules which belong to a depth of 3 will (almost) never show up in the final model.
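The reduction from many trees to a handful of rules can be pictured as counting how often each rule occurs across the forest and keeping only the most frequent ones. The sketch below shows that idea on hand-written rule strings. `select_rules`, the example forest, and the alphabetical tie-breaking are assumptions for illustration; the actual algorithm involves further steps that are omitted here.

```julia
# Count in how many trees each rule occurs and keep the `max_rules` most frequent.
function select_rules(forest::Vector{Vector{String}}; max_rules::Int)
    counts = Dict{String,Int}()
    for tree in forest, rule in unique(tree)
        counts[rule] = get(counts, rule, 0) + 1
    end
    # Sort by descending frequency; break ties alphabetically for determinism.
    ranked = sort(collect(counts); by=p -> (-p.second, p.first))
    return [p.first => p.second / length(forest) for p in first(ranked, max_rules)]
end

# Hypothetical forest of four shallow trees, each reduced to the rules on its paths.
forest = [
    ["nodes < 14", "age < 42"],
    ["nodes < 14", "nodes < 8"],
    ["nodes < 8", "year < 1960"],
    ["nodes < 14", "age < 42"],
]

for (rule, freq) in select_rules(forest; max_rules=2)
    println(rule, " occurs in a fraction ", freq, " of the trees")
end
```

With stabilized trees the same splits recur often, so a small `max_rules` can already cover most of what the forest has learned.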

Furthermore, the SIRUS algorithm also performed almost as well as LightGBM and outperformed a regularized linear classifier in our paper, as shown in Figure 1:

Note that the area under the curve scores are shown in the bottom right of each subplot.

Finally, many benchmarks run as part of the test suite. The results can be found in the summary of the GitHub Actions for the CI workflow. These are the results for version 1.2.0:

30×7 DataFrame
 Row │ Dataset          Model                   Hyperparameters                  nfolds  measure   score   1.96*SE
     │ String           String                  String                           Int64   String    String  String
─────┼─────────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │ blobs            LGBMClassifier          (;)                                  10  auc       0.99    0.01
   2 │ blobs            LGBMClassifier          (max_depth = 2,)                     10  auc       0.99    0.01
   3 │ blobs            StableForestClassifier  (max_depth = 2,)                     10  auc       1.00    0.00
   4 │ blobs            StableRulesClassifier   (max_depth = 2, max_rules = 30)      10  auc       1.00    0.00
   5 │ blobs            StableRulesClassifier   (max_depth = 2, max_rules = 10)      10  auc       1.00    0.00
   6 │ titanic          LGBMClassifier          (;)                                  10  auc       0.87    0.03
   7 │ titanic          LGBMClassifier          (max_depth = 2,)                     10  auc       0.85    0.02
   8 │ titanic          StableForestClassifier  (max_depth = 2,)                     10  auc       0.85    0.02
   9 │ titanic          StableRulesClassifier   (max_depth = 2, max_rules = 30)      10  auc       0.83    0.02
  10 │ titanic          StableRulesClassifier   (max_depth = 2, max_rules = 10)      10  auc       0.82    0.02
  11 │ haberman         LGBMClassifier          (;)                                  10  auc       0.71    0.06
  12 │ haberman         LGBMClassifier          (max_depth = 2,)                     10  auc       0.67    0.05
  13 │ haberman         StableForestClassifier  (max_depth = 2,)                     10  auc       0.71    0.05
  14 │ haberman         StableRulesClassifier   (max_depth = 2, max_rules = 30)      10  auc       0.70    0.08
  15 │ haberman         StableRulesClassifier   (max_depth = 2, max_rules = 10)      10  auc       0.67    0.07
  16 │ iris             LGBMClassifier          (;)                                  10  accuracy  0.94    0.04
  17 │ iris             LGBMClassifier          (max_depth = 2,)                     10  accuracy  0.94    0.04
  18 │ iris             StableForestClassifier  (max_depth = 2,)                     10  accuracy  0.95    0.04
  19 │ iris             StableRulesClassifier   (max_depth = 2, max_rules = 30)      10  accuracy  0.73    0.12
  20 │ iris             StableRulesClassifier   (max_depth = 2, max_rules = 10)      10  accuracy  0.68    0.10
  21 │ make_regression  LGBMRegressor           (;)                                  10  R²        0.78    0.03
  22 │ make_regression  LGBMRegressor           (max_depth = 2,)                     10  R²        0.61    0.04
  23 │ make_regression  StableForestRegressor   (max_depth = 2,)                     10  R²        0.67    0.05
  24 │ make_regression  StableRulesRegressor    (max_depth = 2, max_rules = 30)      10  R²        0.55    0.04
  25 │ make_regression  StableRulesRegressor    (max_depth = 2, max_rules = 10)      10  R²        0.57    0.06
  26 │ boston           LGBMRegressor           (;)                                  10  R²        0.70    0.06
  27 │ boston           LGBMRegressor           (max_depth = 2,)                     10  R²        0.63    0.07
  28 │ boston           StableForestRegressor   (max_depth = 2,)                     10  R²        0.67    0.08
  29 │ boston           StableRulesRegressor    (max_depth = 2, max_rules = 30)      10  R²        0.55    0.07
  30 │ boston           StableRulesRegressor    (max_depth = 2, max_rules = 10)      10  R²        0.63    0.10

This shows that the performance scores are close to LightGBM on several datasets (blobs, titanic, and haberman), while the gap is larger on others such as iris.
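For readers who want to reproduce the `score` and `1.96*SE` columns for their own models, the sketch below computes both from per-fold scores. It assumes that the standard error is the standard deviation of the fold scores divided by the square root of the number of folds; the `summarize` helper and the fold scores are made up for illustration.

```julia
using Statistics: mean, std

# Mean score and half-width of an approximate 95% interval (1.96 standard errors).
function summarize(scores)
    se = std(scores) / sqrt(length(scores))
    return (score=mean(scores), half_width=1.96 * se)
end

# Hypothetical per-fold AUC values from a 10-fold cross-validation.
fold_scores = [0.70, 0.65, 0.74, 0.61, 0.78, 0.69, 0.72, 0.66, 0.75, 0.70]

s = summarize(fold_scores)
println(round(s.score; digits=2), " ± ", round(s.half_width; digits=2))
```

When two models have intervals that overlap heavily, as for several rows above, the difference between them should not be over-interpreted.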

Arguably, the slightly lower performance is not a bad trade for higher interpretability. More interpretable models can lead to better predictive performance in the real world because it’s easier to verify the model and make adjustments if needed.

Example

If you are interested in applying the model to your dataset, you can use it via the MLJ.jl interface. For example, this is the code to fit the model on the full Haberman dataset:

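using MLJ
using SIRUS: StableRulesClassifier
# X (feature table) and y (categorical target) are assumed to be loaded already.
model = StableRulesClassifier(; max_depth=1, max_rules=8)
mach = machine(model, X, y)
fit!(mach)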

To evaluate model performance via cross-validation, use MLJ.evaluate:

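# `nfolds` and `_rng()` are assumed to be defined beforehand,
# for example `nfolds = 10` and a helper that returns a seeded random number generator.
resampling = CV(; nfolds, shuffle=true, rng=_rng())
acceleration = MLJ.CPUThreads()
evaluate(model, X, y; acceleration, resampling, measure=auc)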

The Haberman dataset is a binary classification dataset. The approach for regression datasets is the same, except that you use the StableRulesRegressor instead. For more information about the models and hyperparameters, see the API docs.


This is really cool! The task reminded me of symbolic regression, so I tried using SymbolicRegression.jl to perform a similar task. With a little bit of tuning there (and none for SIRUS), I got it to produce simple rules for regression in the Boston housing task, with a test-set error equivalent to that of SIRUS. I only allowed SymbolicRegression to use +, -, <, and > to match SIRUS’s rules. Both perform much worse than XGBoost in this case though (but maybe they just need tuning).

I made a Pluto notebook but don’t have a good way to share the rendered results (can’t upload HTML here). But here is the notebook: SIRUS_Symbolic_Regression.jl (51.5 KB)

I wonder if there’s a way to use SymbolicRegression as some kind of backend to SIRUS to help power the rule search, or if they are just two different approaches to similar problems. SymbolicRegression can handle more complicated expressions (you can give it arbitrary binary operations to use, not just + and >), which is why I think it might be a good way to add more expressive power to SIRUS.


Just to say, I switched to the Ames housing dataset and tried a bit more. There the regression task is to predict home prices from various variables. I dropped all categorical variables for simplicity, and this time allowed SymbolicRegression to use * and /. After playing with the parameters a bit, I got an RMS error in price of ~29k for XGBoost, ~31k for SymbolicRegression, and ~58k for SIRUS. I’m not sure what test/train split they used, but the XGBoost and SymbolicRegression numbers seem competitive with what I found in this old Kaggle contest: Ames Housing Data | Kaggle.

The formula SymbolicRegression ended up with is

(((((((YearBuilt + -0.48597169501749404) * 1.2438029869962868) + GrLivArea) - ((BsmtFinSF1 - (BedroomAbvGr - LotArea)) / (-2.0259896659909704 / 0.7342618541159212))) + GrLivArea) + TotalBsmtSF) / (3.8119719489105597 - YearBuilt))

which, if I simplify the parentheses, I believe is

((YearBuilt + -0.48597169501749404) * 1.2438029869962868 + 
 GrLivArea - 
 ((BsmtFinSF1 - (BedroomAbvGr - LotArea)) / (-2.0259896659909704 / 0.7342618541159212)) +
 GrLivArea +
 TotalBsmtSF)
/ (3.8119719489105597 - YearBuilt)

which seems pretty nice and simple (where here the variable names refer to Z-scored transformations, which was necessary for both SymbolicRegression and SIRUS to perform OK). SIRUS ended up with these rules:

StableRules model with 20 rules:
 if X[i, :x2ndFlrSF] ≥ 1.0926578 & X[i, :BsmtFinSF1] ≥ 1.5382147 then 0.028 else 0.009 +
 if X[i, :x2ndFlrSF] ≥ 0.82368124 & X[i, :BsmtFinSF1] ≥ 1.5382147 then 0.037 else 0.012 +
 if X[i, :TotalBsmtSF] < 1.3811567 then -0.018 else 0.423 +
 if X[i, :TotRmsAbvGrd] ≥ 1.5452067 & X[i, :BsmtFinSF1] ≥ 1.5382147 then 0.042 else 0.013 +
 if X[i, :BsmtFinSF1] < 1.5382147 then 0.018 else 0.466 +
 if X[i, :GarageCars] < 0.31470308 then -0.028 else 0.339 +
 if X[i, :TotRmsAbvGrd] ≥ 1.5452067 & X[i, :TotalBsmtSF] ≥ 1.3811567 then 0.049 else 0.016 +
 if X[i, :GrLivArea] ≥ 1.4085871 & X[i, :TotalBsmtSF] ≥ 1.3811567 then 0.075 else 0.023 +
 if X[i, :YrSold] ≥ 0.8789953 & X[i, :GarageCars] ≥ 0.31470308 then 0.021 else 0.006 +
 if X[i, :x1stFlrSF] ≥ 1.43799 & X[i, :x2ndFlrSF] ≥ 1.4581902 then 0.029 else 0.009 +
 if X[i, :GrLivArea] ≥ 1.4085871 & X[i, :BsmtFinSF1] ≥ 1.5382147 then 0.056 else 0.017 +
 if X[i, :GrLivArea] < 0.84535515 then -0.024 else 0.245 +
 if X[i, :YearBuilt] ≥ 1.174533 & X[i, :GrLivArea] ≥ 1.4085871 then 0.033 else 0.013 +
 if X[i, :YearBuilt] ≥ 1.174533 & X[i, :GrLivArea] ≥ 0.84535515 then 0.047 else 0.014 +
 if X[i, :GrLivArea] < 1.4085871 then -0.033 else 0.315 +
 if X[i, :TotRmsAbvGrd] ≥ 1.5452067 & X[i, :GarageYrBlt] ≥ 0.6679906 then 0.061 else 0.015 +
 if X[i, :TotRmsAbvGrd] ≥ 1.5452067 & X[i, :TotalBsmtSF] ≥ 0.9063819 then 0.057 else 0.018 +
 if X[i, :TotRmsAbvGrd] ≥ 1.5452067 & X[i, :YearBuilt] ≥ 0.6748743 then 0.061 else 0.015 +
 if X[i, :x2ndFlrSF] < 1.4581902 then 0.007 else 0.353 +
 if X[i, :TotalBsmtSF] < 0.9063819 then -0.025 else 0.218

While XGBoost gives these feature importances for the 10 most important features:

 Row │ feature         gain        weight   cover    total_gain  total_cover
     │ String          Float32     Float32  Float32  Float32     Float32
─────┼───────────────────────────────────────────────────────────────────────
   1 │ "GarageCars"    4.24146f11  10.0     557.4    4.24146f12  5574.0
   2 │ "GrLivArea"     1.41052f10  209.0    287.641  2.94798f12  60117.0
   3 │ "HalfBath"      6.13618f9   14.0     157.0    8.59065f10  2198.0
   4 │ "TotalBsmtSF"   4.58161f9   186.0    225.973  8.5218f11   42031.0
   5 │ "TotRmsAbvGrd"  3.40563f9   48.0     136.646  1.6347f11   6559.0
   6 │ "Fireplaces"    3.15909f9   40.0     228.825  1.26364f11  9153.0
   7 │ "BsmtFinSF1"    3.06287f9   162.0    175.531  4.96186f11  28436.0
   8 │ "YearBuilt"     2.94103f9   202.0    133.02   5.94088f11  26870.0
   9 │ "YearRemodAdd"  2.3613f9    154.0    158.844  3.6364f11   24462.0
  10 │ "FullBath"      2.05867f9   15.0     356.6    3.08801f10  5349.0
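Since the thresholds in the SIRUS rules above are in Z-scored units, they have to be mapped back to the original scale before they can be read as, say, square feet. Here is a minimal sketch of that conversion, with made-up sample values and hypothetical helper names (`zscore`, `unzscore`); the real mean and standard deviation would come from the training data.

```julia
using Statistics: mean, std

# Standardize a value and map a standardized value back to the original scale.
zscore(x, μ, σ) = (x - μ) / σ
unzscore(z, μ, σ) = μ + z * σ

# Made-up living areas in square feet; not taken from the Ames dataset.
area = [850.0, 1200.0, 1500.0, 1710.0, 2100.0, 2600.0]
μ, σ = mean(area), std(area)

threshold_z = 1.4085871   # threshold from one of the GrLivArea rules above
threshold_sqft = unzscore(threshold_z, μ, σ)

println(round(threshold_sqft; digits=1))
```

Reporting the rules on the original scale keeps them readable for people who do not think in standard deviations.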

Here is my notebook: SIRUS_Symbolic_Regression_Ames.jl (54.1 KB)


That’s really nice. I’m definitely gonna look into this! Thanks


This is a neat package @rikh!

This seems a bit similar to signed iterative Random Forests (s-iRF) (paper, github), which I have my eye on for potential use with feature selection. I’ve been hoping to see a high-performance Julia implementation of random intersection trees (gRIT) / iRF / s-iRF for a while now as a feature selection technique. Perhaps this might be of interest to you?

See also: Iterative random forests to discover predictive and stable high-order interactions (PNAS), and iRF.py


Oh, one other thing @rikh. I’d be interested in applying this to a data set that requires a custom bootstrapping approach (extremely imbalanced, groups of rows are highly correlated and can’t be split across the in-bag/OOB samples). Is there any scope to provide a custom bootstrap sampling function, or build the trees separately within a custom bootstrap loop and then feed them to SIRUS.jl to extract the rules?


Yes, it does look like it. I find that paper extremely hard to read because they require medical/chemical domain knowledge for their algorithm description.

This one is much clearer. It looks like the SIRUS algorithm is very similar indeed. If you want to benchmark the iRF algorithm (which was removed from CRAN?), feel free to open a PR and re-use the R benchmarking logic in https://github.com/rikhuijzer/SIRUS.jl/blob/main/test/rcall.jl. I won’t guarantee that I’ll merge the PR, but at least we can re-use the benchmarking setup and see how well it performs.

Given that SIRUS.jl and Julia are often used for smallish research datasets, that feature would make sense, yes. However, many people, myself included, don’t need that feature, so I probably won’t implement it myself. Feel free to open a PR. If the implementation does not add too much complexity, then it will likely be merged. Code for an MLJ wrapper for imbalanced datasets is at Oversampling and undersampling · Issue #661 · alan-turing-institute/MLJ.jl · GitHub (I’m not sure you need this, but the link may be useful if you do).
