I’m trying to train a RandomForestRegressor using DecisionTree.jl and RandomizedSearchCV (from ScikitLearn.jl) in Julia. The datasets (x_train, y_train, etc.) are in my Google Drive as well, so you can test this on your machine. The code is as follows:
using CSV
using DataFrames
using ScikitLearn: fit!, predict
using ScikitLearn.GridSearch: RandomizedSearchCV
using DecisionTree
x = CSV.read("x.csv", DataFrame)              # training features
x_test = CSV.read("x_test.csv", DataFrame)    # test features
y_train = CSV.read("y_train.csv", DataFrame)  # training targets
rf = RandomForestRegressor()                  # base estimator (avoids shadowing Base.mod)
param_dist = Dict("n_trees" => [50, 100, 200, 300],
                  "max_depth" => [3, 5, 6, 8, 9, 10])  # values to sample from
model = RandomizedSearchCV(rf, param_dist, n_iter=10, cv=5)  # 10 draws, 5-fold CV each
fit!(model, Matrix(x), Matrix(dropmissing(y_train)))
predict(model, Matrix(x_test))                # predict needs the fitted model as its first argument
This throws a MethodError like this:
ERROR: MethodError: no method matching fit!(::RandomForestRegressor, ::Matrix{Float64}, ::Matrix{Float64})
Closest candidates are:
fit!(::ScikitLearn.Models.FixedConstant, ::Any, ::Any) at C:\Users\Shayan\.julia\packages\ScikitLearn\ssekP\src\models\constant_model.jl:26
fit!(::ScikitLearn.Models.ConstantRegressor, ::Any, ::Any) at C:\Users\Shayan\.julia\packages\ScikitLearn\ssekP\src\models\constant_model.jl:10
fit!(::ScikitLearn.Models.LinearRegression, ::AbstractArray{XT}, ::AbstractArray{yT}) where {XT, yT} at C:\Users\Shayan\.julia\packages\ScikitLearn\ssekP\src\models\linear_regression.jl:27
...
Stacktrace:
[1] _fit!(self::RandomizedSearchCV, X::Matrix{Float64}, y::Matrix{Float64}, parameter_iterable::Vector{Any})
@ ScikitLearn.Skcore C:\Users\Shayan\.julia\packages\ScikitLearn\ssekP\src\grid_search.jl:332
[2] fit!(self::RandomizedSearchCV, X::Matrix{Float64}, y::Matrix{Float64})
@ ScikitLearn.Skcore C:\Users\Shayan\.julia\packages\ScikitLearn\ssekP\src\grid_search.jl:748
[3] top-level scope
@ c:\Users\Shayan\Desktop\AUT\Thesis\test.jl:17
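For anyone who can't download the files, here is a minimal sketch that should reproduce the same kind of error with random data of the same shapes (assuming the real columns are all numeric). It calls DecisionTree.jl's fit! directly, without RandomizedSearchCV, so it checks whether the search wrapper is involved at all. The hyperparameter values are arbitrary placeholders.

```julia
using DecisionTree  # RandomForestRegressor and its ScikitLearn-style fit!

# Random stand-ins with the same shapes as the real x and y_train
X = rand(1550, 70)
Y = rand(1550, 10)

rf = RandomForestRegressor(n_trees = 10, max_depth = 3)

# Pass a 1550×10 matrix of targets and capture whatever is thrown
err = try
    DecisionTree.fit!(rf, X, Y)
    nothing
catch e
    e
end
println(typeof(err))  # prints MethodError

# For contrast: a single target column (a Vector{Float64}) is accepted
DecisionTree.fit!(rf, X, Y[:, 1])
```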
If you’re curious about the shape of the data:
julia> size(x)
(1550, 70)
julia> size(y_train)
(1550, 10)
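The shapes suggest one possible cause: y_train has 10 columns, while DecisionTree.jl's fit! for RandomForestRegressor appears to accept the targets only as a vector (one output per sample). A minimal sketch, assuming it is acceptable to fit one forest per target column, using random stand-ins for the real data; the hyperparameter values are placeholders, not tuned ones:

```julia
using DecisionTree  # RandomForestRegressor with ScikitLearn-style fit!/predict
using Random

rng = MersenneTwister(1)
X = rand(rng, 1550, 70)    # stand-in for Matrix(x)
Y = rand(rng, 1550, 10)    # stand-in for Matrix(y_train)
X_new = rand(rng, 20, 70)  # stand-in for Matrix(x_test)

# One forest per target column; Y[:, j] is a Vector{Float64}
forests = map(1:size(Y, 2)) do j
    rf = RandomForestRegressor(n_trees = 50, max_depth = 5)
    DecisionTree.fit!(rf, X, Y[:, j])
    rf
end

# Each forest predicts one column; stack the columns into a 20×10 matrix
Y_pred = reduce(hcat, [DecisionTree.predict(f, X_new) for f in forests])
println(size(Y_pred))  # prints (20, 10)
```

With the real data, the same loop would use Matrix(x) and a column such as Matrix(y_train)[:, j]. Presumably that vector column could also be passed to the RandomizedSearchCV object in place of the full matrix, giving one search per target, but whether the search then runs end to end with DecisionTree.jl's regressor is an assumption this sketch does not check.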
How can I solve this problem? Any help would be appreciated.