MLJ - user defined models - supporting weights - issue?

Hi - I’ve been using MLJ successfully to create some simple custom models with MLJModelInterface. Recently, I’ve tried to extend these models to support sample weights so that I can implement a simple adaptive LASSO. However, MLJ keeps throwing an "unsupported keyword argument" error when I try to use “w” in “fit”, following the method shown in the MLJ docs. Example code is below. My original code involves kernels; I’ve removed everything that is superfluous to my question, but kept the code in a form that is still relevant to what I’m trying to do. That is the reason for what would otherwise be a very questionable construct!
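For context, rescaling the columns of X by first-stage Lasso coefficients is the usual reparametrization of the adaptive LASSO. The sketch below assumes scikit-learn's 1/(2n) scaling of the squared loss and weights w_j equal to the absolute coefficients from a first Lasso fit (the γ = 1 case of the adaptive LASSO):

```latex
\hat\beta = \arg\min_{\beta} \; \frac{1}{2n}\,\lVert y - X\beta \rVert_2^2
          + \alpha \sum_{j=1}^{p} \frac{\lvert \beta_j \rvert}{w_j},
\qquad w_j = \lvert \hat\beta^{(0)}_j \rvert .

% Substitute \beta = W\theta with W = \operatorname{diag}(w_1, \dots, w_p), i.e. \beta_j = w_j \theta_j:
\hat\theta = \arg\min_{\theta} \; \frac{1}{2n}\,\lVert y - (XW)\theta \rVert_2^2
           + \alpha \sum_{j=1}^{p} \lvert \theta_j \rvert,
\qquad \hat\beta = W\hat\theta,
\qquad X_{\text{new}}\hat\beta = (X_{\text{new}} W)\,\hat\theta .
```

So an ordinary Lasso fitted to the rescaled design `X .* w'` gives θ̂, and predicting with new features rescaled the same way reproduces the adaptive-LASSO predictions. A feature with w_j = 0 is effectively dropped, and because |β_j| / |w_j| = |θ_j|, the sign of w_j does not change the penalty, so passing raw (signed) coefficients as weights gives the same fit.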

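#load packages; MMI is used below as shorthand for MLJModelInterface
using MLJ, Statistics, Printf
import MLJModelInterface
const MMI = MLJModelInterface

#set up simplistic construct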
mutable struct YourModel <: MLJModelInterface.Deterministic
    lambda::Float64
    mdl::Deterministic
end
#mdl (required): an instance of the wrapped deterministic model, e.g. created from a type loaded with @load
function YourModel(; lambda=1e-4, mdl)
    model   = YourModel(lambda, mdl)
    message = MLJModelInterface.clean!(model)
    isempty(message) || @warn message
    return model
end
function MLJModelInterface.clean!(m::YourModel)
    warning = ""
    if m.lambda <= 0
        warning *= "Parameter `lambda` expected to be positive, resetting to 1e-4"
        m.lambda = 1e-4
    end
    return warning
end

function MLJModelInterface.fit(m::YourModel, verbosity, X, y, w=nothing)
    
    #rescale for adaptive lasso approach, if weights passed (column j is multiplied by w[j])
    if !isnothing(w)
        Xw = table(matrix(X) .* w')
    else
        Xw = X
    end
    
    #create machine based on passed model and parameters
    mach = machine(m.mdl, Xw, y)

    #fit the machine, set verbosity to 0 to reduce messages;
    #fit! returns the trained machine itself, which predict uses below
    core_fitresult = MLJ.fit!(mach, verbosity=0)

    #return fitresult (trained machine, training X, weights), cache, report;
    #the cache holds the inner model's fitted parameters (e.g. coef)
    cache  = fitted_params(mach)
    report = nothing 
    return (core_fitresult, X, w), cache, report
end

function MMI.predict(m::YourModel, fitresult, Xnew)
    #pull out core fitresult, X, and weights 
    core_fitresult, Xold, w = fitresult
    
    #rescale for adaptive lasso approach
    if !isnothing(w)
        Xw = table(matrix(Xnew) .* w')
    else
        Xw = Xnew
    end
        
    return MLJ.predict(core_fitresult, Xw)
end
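The `nothing` check on the weights appears in both `fit` and `predict`, and the two must rescale in exactly the same way. One way to keep them in sync is a small helper that dispatches on `Nothing`. This is a minimal sketch; `rescale` is a hypothetical name, not an MLJ function, and the example data are made up for illustration.

```julia
# Hypothetical helper (not part of MLJ): multiply column j of X by w[j].
# Dispatching on `Nothing` means "no weights, leave X unchanged".
rescale(X::AbstractMatrix, ::Nothing) = X
rescale(X::AbstractMatrix, w::AbstractVector) = X .* w'   # w' is a row, broadcast down the rows

X = [1.0 2.0; 3.0 4.0]
w = [10.0, 0.5]
θ = [0.2, 4.0]            # coefficients learned on the rescaled features

rescale(X, w)             # [10.0 1.0; 30.0 2.0]

# Predicting on rescaled features with θ equals predicting on the
# original features with the back-transformed coefficients w .* θ
rescale(X, w) * θ ≈ X * (w .* θ)   # true
```

Inside `fit` and `predict` this would read `Xw = table(rescale(matrix(X), w))`, at the cost of a table-to-matrix round trip when no weights are given.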


#declare target-specific composite model (simplistic construct)
mlj_c_mdl = YourModel(mdl= (@load LassoRegressor pkg="ScikitLearn" verbosity=0)())
#set the regularization strength (alpha) of the wrapped Lasso model
mlj_c_mdl.mdl.alpha=1e-5


Xt, y = @load_boston;
X = matrix(Xt)

#split the row indices into train and test sets
train_idx, test_idx = partition(eachindex(y), 0.7, shuffle=true); #70:30 train:test split

#create machine
opt_mc = machine(mlj_c_mdl, table(X[train_idx,:]), y[train_idx])
#fit the machines
MLJ.fit!(opt_mc, verbosity=0)
#predict
yhat = MLJ.predict(opt_mc, table(X[test_idx,:]))
#root-mean-square error
current_rms = sqrt(mean((yhat .- y[test_idx]).^2))
@printf("RMS is: \$%0.2fk dollars",current_rms)
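The RMS error is the square root of the mean squared error. Taking `sqrt` element-wise before `mean`, as in `mean(sqrt.((yhat - y).^2))`, gives the mean absolute error instead, because `sqrt(e^2) == abs(e)`. A small sketch contrasting the two on made-up numbers (`rmse` and `mae` are hypothetical helper names):

```julia
using Statistics

# Root-mean-square error: square, average, then take the square root
rmse(ŷ, y) = sqrt(mean((ŷ .- y) .^ 2))

# Taking sqrt element-wise before averaging yields the mean absolute error
mae(ŷ, y) = mean(sqrt.((ŷ .- y) .^ 2))

ŷ = [1.0, 2.0, 3.0]
y = [1.0, 2.0, 6.0]
rmse(ŷ, y)   # sqrt(3) ≈ 1.732
mae(ŷ, y)    # 1.0
```

The two agree only when all errors have the same magnitude; otherwise the RMS weights large errors more heavily.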

This works fine for the first fit above, but fails below when I try to pass “w”:

#reweight using coef from first Lasso
#create machine with weight passed... this fails
opt_mc = machine(mlj_c_mdl, table(X[train_idx,:]), y[train_idx], w=opt_mc.cache.coef)
#fit the machines
MLJ.fit!(opt_mc, verbosity=0)
#predict
yhat = MLJ.predict(opt_mc, table(X[test_idx,:]))

#root-mean-square error
current_rms = sqrt(mean((yhat .- y[test_idx]).^2))
@printf("RMS is: \$%0.2fk dollars",current_rms)
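The two-stage procedure above (fit a Lasso, then refit on columns rescaled by its coefficients) can also be sketched without MLJ or scikit-learn, which makes it easy to check on a toy problem. This is a minimal sketch under stated assumptions, not the thread's implementation: it swaps in a plain cyclic coordinate-descent solver, uses scikit-learn's 1/(2n) loss scaling, fits no intercept, and takes absolute first-stage coefficients as weights. `lasso_cd` and `adaptive_lasso` are hypothetical names.

```julia
using LinearAlgebra

# Soft-thresholding operator: S(z, α) = sign(z) * max(|z| - α, 0)
soft(z, α) = sign(z) * max(abs(z) - α, zero(z))

# Lasso by cyclic coordinate descent, minimising
#     (1/(2n)) * ||y - X*β||² + α * ||β||₁
# No intercept is fitted, so in practice X and y should be centred first.
function lasso_cd(X::AbstractMatrix, y::AbstractVector, α::Real; iters::Int = 100)
    n, p = size(X)
    β = zeros(p)
    r = float.(y)                          # residual y - X*β, with β = 0 initially
    for _ in 1:iters
        for j in 1:p
            xj = view(X, :, j)
            z = dot(xj, xj) / n
            z == 0 && continue             # all-zero column (e.g. weight 0): β[j] stays 0
            ρ = dot(xj, r) / n + z * β[j]  # correlation with the partial residual
            b = soft(ρ, α) / z
            r .-= xj .* (b - β[j])         # keep the residual in sync with β
            β[j] = b
        end
    end
    return β
end

# Two-stage adaptive lasso: first-stage |coefficients| become column weights,
# an ordinary lasso is fitted on the rescaled columns, and the result is
# mapped back to the original feature scale.
function adaptive_lasso(X::AbstractMatrix, y::AbstractVector, α::Real; iters::Int = 100)
    w = abs.(lasso_cd(X, y, α; iters = iters))   # stage 1
    θ = lasso_cd(X .* w', y, α; iters = iters)   # stage 2 on rescaled columns
    return w .* θ
end

X = [1.0 0.0; 0.0 1.0; 1.0 0.0; 0.0 1.0]
y = [3.0, 1.0, 3.0, 1.0]
lasso_cd(X, y, 0.5)         # [2.0, 0.0]
adaptive_lasso(X, y, 0.5)   # [2.5, 0.0]
```

On this toy problem the least-squares fit is [3.0, 1.0]; the first Lasso shrinks the strong feature to 2.0 and zeroes the weak one, and the adaptive step shrinks the surviving feature less (2.5) while keeping the dropped one at zero.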

The error I get is:

MethodError: no method matching Machine(::YourModel, ::Source, ::Source; w=[-0.12638333067060678, 0.04755297485286125, 0.013791743556274645, -20.950523452463056, 2.9617997607112896, 0.028587424449456347, -1.5399207663605692, 0.3913773956819799, -0.013500494268412846, -1.0609285377525144, 0.00890119347944057, -0.6414059719669527])
Closest candidates are:
  Machine(::M, ::AbstractNode...; cache) where M<:Model at /home/XXX/.julia/packages/MLJBase/KWyqX/src/machines.jl:32 got unsupported keyword argument "w"

I think you need to pass in the weights as a positional argument:

machine(model, X, y, w)

I thought there used to be examples like that in the documentation, but I can’t find them at the moment.
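The error is ordinary Julia dispatch behaviour: the closest candidate in the error takes the model plus trailing positional arguments (`::AbstractNode...`), and the only keyword it lists is `cache`, so a keyword named `w` matches no method. A minimal plain-Julia sketch of the same situation, using a hypothetical `mock_machine` stand-in rather than MLJ itself:

```julia
# Hypothetical stand-in for a constructor like MLJ's `machine`: data (and any
# weights) are trailing positional arguments; the only keyword is `cache`.
mock_machine(model, args...; cache::Bool = true) = (model = model, args = args, cache = cache)

# Weights passed positionally are simply one more data argument
m = mock_machine(:lasso, :X, :y, [0.5, 2.0])
length(m.args)   # 3

# Passed as a keyword, `w` matches no method, so Julia throws a MethodError
err = try
    mock_machine(:lasso, :X, :y; w = [0.5, 2.0])
catch e
    e
end
err isa MethodError   # true
```

With MLJ itself this corresponds to `machine(model, X, y, w)`. MLJ treats that fourth argument as per-observation sample weights, and MLJModelInterface has a `supports_weights` trait that a model can declare; since the weights in this thread are per-feature coefficients, storing them as a hyperparameter of `YourModel` might sit more naturally with MLJ's conventions.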

Ah! Yes, you’re right. I didn’t need to specify it as the keyword ‘w=’. That’s a bit boneheaded on my part. Thank you!

Easy mistake to make, since that feature seems to be short on documentation.

Thanks, @CameronBieganek.
