[ANN] NeuroTreeModels.jl - Differentiable tree-based models for tabular data

Announcement for the initial release of NeuroTreeModels.jl

NeuroTree-based models comprise a collection of differentiable trees, in an attempt to get both the performance benefits of boosted-tree methods and the flexibility of gradient-based learning. A more comprehensive description can be found in the docs' design section: NeuroTree - A differentiable tree operator for tabular data | NeuroTreeModels
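For a flavour of the API, here is a minimal sketch of fitting a regression model. The `fit` keyword names (`target_name`, `feature_names`), the hyperparameters, and the callable-model prediction are assumptions from memory rather than a verified signature, so treat the docs as authoritative:

```julia
using NeuroTreeModels, DataFrames

# toy regression data: two features and a noisy linear target
df = DataFrame(x1=randn(1000), x2=randn(1000))
df.y = df.x1 .+ 2 .* df.x2 .+ 0.1 .* randn(1000)

# hypothetical config: loss and hyperparameter names are assumptions
config = NeuroTreeRegressor(loss=:mse, nrounds=10)

# keyword names here are assumptions, not the confirmed API
m = NeuroTreeModels.fit(config, df; target_name=:y, feature_names=[:x1, :x2])

# assumed callable-model prediction; check the docs for the actual API
p = m(df)
```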

Comprehensive benchmarks have been run against XGBoost, LightGBM, CatBoost and EvoTrees on 6 datasets commonly used in publications on ML methods for tabular data. Results and the code to reproduce them can be found at MLBenchmarks.jl.

NeuroTree shares similarities with Yandex's Neural Oblivious Decision Ensembles (NODE). Key differences include:

  • Full binary trees (rather than oblivious ones).
  • Reliance on a simple NeuroTree operator that behaves similarly to a Dense operator for tabular, 2D input data. Such an operator can be composed like a Dense operator in Flux chains to build more complex models, like stacks of trees or combinations with any other operators (see the sketch below).
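To illustrate the composability point, here is a hedged sketch of mixing a NeuroTree operator with a plain Dense layer in a Flux chain. The `NeuroTree(in => out; ...)` constructor shape and its keyword names are assumptions based on the Dense analogy above, not a verified signature:

```julia
using Flux, NeuroTreeModels

nfeats, nout = 10, 1

# assumed Dense-like constructor: NeuroTree(in => out; depth, ntrees)
model = Chain(
    NeuroTree(nfeats => 16; depth=4, ntrees=32),  # differentiable tree block
    Dense(16 => nout),                            # ordinary Flux layer on top
)

x = rand(Float32, nfeats, 64)  # features × batch, as with Dense
y = model(x)
```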
28 Likes

Thank you for the package and the benchmark results. I didn't know about CatBoost before; its performance looks even better and promising.

1 Like

Some meaningful updates with the v1.3.0 release (illustrated in a sketch after the list):

  • The kwarg device (:cpu / :gpu) is moved from NeuroTreeRegressor to fit. Same for gpuID.
  • Removal of the outsize argument.
  • Introduction of NeuroTreeClassifier, respecting the MLJ interface.
  • There's no longer a need to specify the number of classes (formerly through the deprecated outsize kwarg); it's automatically detected through the number of levels of the target variable.
  • Classification tasks (using NeuroTreeClassifier) now require the target variable to be Categorical.
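To make the new conventions concrete, a hedged sketch of a classification fit under v1.3.0. The `device` / `gpuID` placement on `fit` and the Categorical target come from the notes above; the dataset, hyperparameters, and remaining keyword names are illustrative assumptions:

```julia
using NeuroTreeModels, DataFrames, CategoricalArrays

df = DataFrame(x1=randn(500), x2=randn(500))
df.y = categorical(rand(["a", "b", "c"], 500))  # target must now be Categorical

# number of classes is inferred from the levels of :y (no outsize kwarg)
config = NeuroTreeClassifier(nrounds=10)  # hyperparameter names are assumptions

# device and gpuID now belong to fit, not the model constructor
m = NeuroTreeModels.fit(config, df;
    target_name=:y, feature_names=[:x1, :x2],
    device=:gpu, gpuID=0)
```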
3 Likes

Just saw the talk. Fantastic. I have been doing research into differentiable tree models.

I finally have some time and might restart the effort on JLBoost.jl, which, in contrast to XGBoost and CatBoost, is better suited to experimentation because its structure is more hackable and lets the user tweak many aspects of the boosting process.

2 Likes

Which talk? Is there a link to video?

2 Likes