MLJ with ScikitLearn: passing return_std to predict

Various scikit-learn models accept return_std=true when calling predict, for example BayesianRidge (BayesianRidgeRegressor in MLJ); see this example. With a BayesianRidgeRegressor or a similar machine, I would like to call
y_predict, y_std = predict(machine, X, return_std=true)
I have looked through ScikitLearn.jl and MLJScikitLearnInterface.jl and do not see any way to make this work, but maybe I am missing something simple, like the right way to pass additional arguments? Thanks.

You’re not missing anything; there’s currently no way to pass that argument. It might be good to open an issue at MLJScikitLearnInterface to discuss this (and you could paste what follows).

I doubt that MLJ’s predict signature will be adapted to match this one, but I’ll let @ablaom or @samuel_okon discuss that.

What could work is to add return_std as a new field of BayesianRidgeRegressor here: MLJScikitLearnInterface.jl/linear-regressors.jl at 36882f14321e7e9889aac31447eeed0102eb052f · JuliaAI/MLJScikitLearnInterface.jl · GitHub

and then pick it up at predict time here: MLJScikitLearnInterface.jl/macros.jl at 36882f14321e7e9889aac31447eeed0102eb052f · JuliaAI/MLJScikitLearnInterface.jl · GitHub

This would also require ScikitLearn.jl to allow passing return_std=true to predict, which might mean opening an issue there as well (cc @cstjean).


Very helpful. I will follow up as suggested.


@tlienart posted on the ScikitLearn.jl site showing that ScikitLearn.jl already supports the return_std=true argument via keyword arguments. So perhaps the only barrier to using this from MLJ is getting MLJScikitLearnInterface.jl to support the call. I posted an issue on the MLJScikitLearnInterface repository, and it looks like the package maintainer will look into this sometime in January, so eventually there may be a fix. Meanwhile, if anyone has a suggestion for a workaround for making the call from MLJ, I would be happy to hear it. Thanks.

PS: I mean a workaround other than calling ScikitLearn.jl directly, which I will try, though I would rather make the call within MLJ.
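A minimal sketch of the direct ScikitLearn.jl route, assuming ScikitLearn.jl is installed with a working Python scikit-learn behind it, and relying on the keyword forwarding mentioned above (so return_std=true reaches scikit-learn's predict, which then returns a (mean, std) tuple). @sk_import is placed at top level, as in the ScikitLearn.jl README; the wrapper name fit_predict_with_std and the synthetic data are hypothetical, for illustration only.

```julia
using ScikitLearn

# @sk_import at top level (module scope), as in the ScikitLearn.jl README.
@sk_import linear_model: BayesianRidge

# Hypothetical helper: fit a BayesianRidge model and return the predictive
# mean and standard deviation for new inputs. Keyword arguments given to
# `predict` are forwarded to the Python estimator's `predict` method.
function fit_predict_with_std(X::AbstractMatrix, y::AbstractVector, Xnew::AbstractMatrix)
    model = BayesianRidge()
    fit!(model, X, y)
    y_pred, y_std = predict(model, Xnew; return_std=true)
    return y_pred, y_std
end

# Synthetic data: 50 rows, 3 features, known linear signal plus noise.
X = randn(50, 3)
y = X * [1.0, -2.0, 0.5] .+ 0.1 .* randn(50)

y_pred, y_std = fit_predict_with_std(X, y, X[1:5, :])
```

Wrapping the call in a function like this keeps the macro out of local scope, which is one reason to keep @sk_import at the top of the script or module.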

The pure MLJ path is likely not going to be trivial, though I’m sure Anthony will provide good input on that front in January. The difficulty is that it would potentially require a change in the API (to allow kwargs to be passed to MLJBase.predict), and the semantics of return_std are a bit unusual (it changes the return value from a single vector to a (mean, std) tuple), so such a change might not generalise very well.

Another approach, probably the cleanest, would be to add a probabilistic BayesianRidge model to MLJScikitLearnInterface that returns Gaussian distributions, from which you can extract the std.
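To illustrate why the probabilistic route is cleaner: in MLJ, predict on a probabilistic machine returns one distribution per observation, so the std comes from the standard Distributions.jl accessors rather than from an extra keyword. A minimal sketch using Distributions.jl directly; the mean and std values are made up and stand in for what such a model (which does not exist yet) might return.

```julia
using Distributions

# Made-up predictive means and standard deviations standing in for
# a Bayesian regressor's output (illustrative values only).
y_mean = [1.5, 2.5, -0.3]
y_std  = [0.4, 0.6, 0.5]

# A probabilistic model would return one Gaussian per observation.
y_dist = Normal.(y_mean, y_std)

# Point predictions and uncertainties are recovered with the
# standard Distributions.jl accessors.
means = mean.(y_dist)
stds  = std.(y_dist)
```

With an MLJ machine mach wrapping such a model, this would read std.(predict(mach, Xnew)), and MLJ's predict_mean(mach, Xnew) would give the means.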

Lastly, since you asked for a workaround: this std is very easy to compute (see scikit-learn/ at dc580a8ef5ee2a8aea80498388690e2213118efd · scikit-learn/scikit-learn · GitHub), so in the meantime you could just mimic that. Calling fitted_params(mach) on the trained machine will give you, e.g., alpha and the other quantities you need for the computation.
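A minimal sketch of that computation, mirroring the formula in scikit-learn's BayesianRidge.predict: y_std = sqrt(diag(X * Σ * X') + 1/α), where Σ is the posterior covariance of the weights (scikit-learn's sigma_) and α is the estimated noise precision (alpha_). The function and argument names are hypothetical; whether fitted_params exposes Σ and α, and under which names, is an assumption to check, and any input preprocessing scikit-learn applies is ignored here.

```julia
using LinearAlgebra

# Hypothetical helper: predictive mean and std for a Bayesian ridge model.
#   coef, intercept : posterior mean weights and intercept
#   sigma           : posterior covariance of the weights (sklearn's sigma_)
#   alpha           : estimated noise precision (sklearn's alpha_)
function bayesian_ridge_predict(X::AbstractMatrix, coef::AbstractVector,
                                intercept::Real, sigma::AbstractMatrix, alpha::Real)
    y_mean = X * coef .+ intercept
    # Row-wise quadratic form x_i' * sigma * x_i, i.e. diag(X * sigma * X').
    quad = vec(sum((X * sigma) .* X; dims=2))
    # Add the noise variance 1/alpha before taking the square root.
    y_std = sqrt.(quad .+ 1 / alpha)
    return y_mean, y_std
end

# Small worked example with hand-picked values.
X = [1.0 0.0; 0.0 2.0]
coef = [1.0, 2.0]
sigma = [0.5 0.0; 0.0 0.25]
y_mean, y_std = bayesian_ridge_predict(X, coef, 0.5, sigma, 4.0)
```

Here y_mean is [1.5, 4.5] and y_std is sqrt.([0.75, 1.25]): the quadratic forms are 0.5 and 1.0, each plus the noise variance 1/4.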

Thank you, I appreciate all the help you have given on this. I agree that all the Bayesian routines in scikit-learn should be probabilistic models in MLJ; that would be best in the long run.

For now, I managed to call ScikitLearn.jl directly. In doing that, I found that calling ScikitLearn.jl is easy from the REPL, but to make the call from a function in the way that I needed, I had to resort to an obscure workaround:

Perhaps there is some other way, but I could not find it. It seems strange that something as basic as calling via a function requires an undocumented hack. In any case, I have a simple workaround for now that allows me to move ahead.

I'm not following this discussion in detail as I'm on leave. However, the poster may find SossMLJ.jl useful. It provides some Bayesian models with MLJ interfaces. They are not registered, so they cannot be loaded with @load, but they can still be loaded explicitly. If I recall correctly, predict returns distribution-like objects, although they do not support the pdf interface (only rand?). There is also a predict_joint for predicting a single multivariate distribution; see the docs for details.