I would like to know how to train an MLJ model on a GPU.


using MLJ
X = MLJ.table(rand(100, 10));
y = 2X.x1 - X.x2 + 0.05*rand(100);

@load LinearRegressor pkg=MLJLinearModels verbosity=0;
model = LinearRegressor()

mach = machine(model, X, y);
fit!(mach);

params = fitted_params(mach)
params.coefs     # coefficients of the regression, with names
params.intercept # intercept

Xnew  = MLJ.table(rand(3, 10));
ypred = predict(mach, Xnew)

What would be the correct approach to implementing GPU support for this model?

Hmm, so I’m the current (lazy) maintainer of MLJLinearModels, and there’s indeed no explicit GPU support. Note that using a GPU for linear regression seems a bit overkill, but maybe you have a use case that requires it, with giant data or something.

Some of the package relies on IterativeSolvers.jl, which does support GPU arrays (e.g. the Conjugate Gradients section of the IterativeSolvers.jl docs), but it seems to require some care to ensure all vectors are on the GPU. I’ve not tried this, though if someone has a clear idea of what’s required, I can try to help expose this.
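For concreteness, here is a minimal sketch of what using IterativeSolvers.jl directly with GPU arrays might look like, assuming an NVIDIA GPU with CUDA.jl installed. The problem setup (sizes, the SPD matrix construction) is purely illustrative; this is not something MLJLinearModels currently does for you.

```julia
using CUDA, IterativeSolvers, LinearAlgebra

n = 1_000
A_cpu = rand(n, n)
A_cpu = A_cpu'A_cpu + n*I   # make it symmetric positive definite, as cg requires
b_cpu = rand(n)

A = CuArray(A_cpu)          # move the system matrix to the GPU
b = CuArray(b_cpu)          # move the right-hand side to the GPU

x = cg(A, b)                # conjugate gradients; iteration vectors stay on the GPU
x_cpu = Array(x)            # bring the solution back to the CPU
```

The key point is that `cg` is generic: if `A` and `b` are `CuArray`s, the internal vectors it allocates follow suit, which is why "all vectors on the GPU" matters.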

Edit: note that some models that can be called via MLJ do have GPU support (e.g. Flux), but I don’t know whether this “just works” or requires careful data handling by MLJ; @ablaom or @samuel_okon should be able to give clearer explanations on that front.


Yes, GPU support in MLJ is model-specific. (Meta-algorithms, such as hyper-parameter tuning, do support multi-processing and multi-threading, but there’s no real sense in supporting GPU for this, I’d say.)

The MLJFlux models support training on a GPU. You present your data as normal and the transfer to the GPU is handled under the hood. You enable the GPU for training by setting the hyperparameter acceleration=CUDALibs().
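As a sketch of what that looks like in practice, here is the same tabular workflow as above but with an MLJFlux neural-network regressor trained on the GPU. This assumes MLJ, MLJFlux, and a working CUDA installation; the `epochs` value is illustrative.

```julia
using MLJ

# Load the MLJFlux regressor (requires MLJFlux in the environment)
NeuralNetworkRegressor = @load NeuralNetworkRegressor pkg=MLJFlux verbosity=0

X = MLJ.table(rand(100, 10));
y = 2X.x1 - X.x2 + 0.05*rand(100);

# acceleration=CUDALibs() requests GPU training; data transfer is
# handled under the hood, so X and y are presented as normal
model = NeuralNetworkRegressor(epochs=50, acceleration=CUDALibs())

mach = machine(model, X, y)
fit!(mach)
ypred = predict(mach, MLJ.table(rand(3, 10)))
```

Note that `CUDALibs` comes from ComputationalResources.jl, which MLJ re-exports.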


Thank you @tlienart for the response and explanation. Yes, I am working with large datasets and was trying to reduce the training time. I will have a look into IterativeSolvers.jl and report back with my findings on GPU support.

Thanks for maintaining the package, amazing work!


Thanks @ablaom for the response and for clearing up my doubts regarding GPU support. I did manage to run my linear model with multi-threading, which reduced the time by almost 50%.
I will look into MLJFlux for a better understanding of the GPU implementation.
I will also try converting my data to a GPU-compatible type, i.e. CuArray, and see if it works; I’ll update with my findings!
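For anyone reading along, the kind of multi-threaded speed-up mentioned above can be had through MLJ’s tuning meta-algorithm, which evaluates candidate models in parallel when given `acceleration=CPUThreads()` (start Julia with `julia -t auto` so threads are available). The ridge model and the `lambda` range here are illustrative choices, not from the original post.

```julia
using MLJ

RidgeRegressor = @load RidgeRegressor pkg=MLJLinearModels verbosity=0

X = MLJ.table(rand(100, 10));
y = 2X.x1 - X.x2 + 0.05*rand(100);

model = RidgeRegressor()
r = range(model, :lambda, lower=1e-3, upper=1.0, scale=:log)

tuned = TunedModel(model=model,
                   range=r,
                   resampling=CV(nfolds=5),
                   measure=rms,
                   acceleration=CPUThreads())  # folds/candidates run on threads

mach = machine(tuned, X, y)
fit!(mach)
fitted_params(mach).best_model  # the winning lambda
```

This parallelizes the resampling/tuning loop on the CPU, which is orthogonal to any GPU support inside the model itself.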
