[Design pattern] How can my package internally use functions from another package it knows nothing about?

Hello, my objective is to have an imputer in my machine learning package that, for each dimension (column) of the provided matrix, can internally use as the actual imputer any estimator that provides the “interface” mod=ModelType(options), fit!(mod,X,y) and predict(mod, X), for example DecisionTreeRegressor from the DecisionTree package.

Currently it works correctly with estimators provided by BetaML itself, but I have a problem when I want to use estimators from any other package that BetaML knows nothing about.

Indeed if the user does:


using BetaML, AnotherPackage

mod = BetaML.UniversalImputer(estimators=AnotherPackage.AnEstimatorType(its_options),other_parameters)

BetaML.fit!(mod,X)

X_imputed = BetaML.predict(mod)

If, inside BetaML.fit!, I then have for each column of X something similar to:

fit!(col_model,other_X_cols,col) and predict(col_model,X_row), then I get a MethodError saying that no method matches fit!(::Type{AnotherPackage.AnEstimatorType},Matrix,Vector), and the same for predict().
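A minimal sketch of why such a MethodError can appear, using two hypothetical modules (PkgA and PkgB stand in for BetaML and AnotherPackage; none of the names below come from either package). When two packages each define their own fit! without extending a common function, Julia treats them as two unrelated generic functions:

```julia
# Hypothetical stand-ins for two unrelated packages.
module PkgA
    struct ModelA end
    fit!(m::ModelA, X, y) = "PkgA.fit! called"
end

module PkgB
    struct ModelB end
    fit!(m::ModelB, X, y) = "PkgB.fit! called"
end

X = rand(3, 2)
y = rand(3)

# The two names refer to different generic functions.
same_function = PkgA.fit! === PkgB.fit!

# PkgA.fit! has no method for PkgB's model, so this call throws.
err = try
    PkgA.fit!(PkgB.ModelB(), X, y)
catch e
    e
end
```

Calling PkgB.fit! explicitly works, which is why the imputer needs some way to learn which function to call.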

How can I rearrange the code of my package so that it can use the correct fit! and predict functions of AnotherPackage without having AnotherPackage as a dependency or extension? Perhaps by having the user provide these functions to the UniversalImputer constructor? Is there a way I could hide this complexity from the user?

One solution would be to use the functions defined by a common third-party interface, in this case maybe StatsAPI.jl? That’s where I buy all my fit! functions.
But of course it only works if you and AnotherPackage.jl agree beforehand on the use of StatsAPI.jl.
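The shared-interface idea can be sketched as follows. To keep the example runnable without installing anything, a hypothetical local module SharedAPI plays the role that StatsAPI.jl would play; MeanEstimator, MyImputer and fit_and_predict are also made up for illustration:

```julia
# Hypothetical stand-in for a shared interface package such as StatsAPI.jl:
# it only declares the generic functions and defines no methods.
module SharedAPI
    function fit! end
    function predict end
end

# Hypothetical estimator package: it extends the shared functions for its own type.
module AnotherPackage
    import ..SharedAPI
    mutable struct MeanEstimator
        mean::Float64
    end
    MeanEstimator() = MeanEstimator(0.0)
    function SharedAPI.fit!(m::MeanEstimator, X, y)
        m.mean = sum(y) / length(y)
        return m
    end
    SharedAPI.predict(m::MeanEstimator, X) = fill(m.mean, size(X, 1))
end

# Hypothetical imputer package: it only calls the shared functions and
# never mentions AnotherPackage.
module MyImputer
    import ..SharedAPI
    function fit_and_predict(model, X, y)
        SharedAPI.fit!(model, X, y)
        return SharedAPI.predict(model, X)
    end
end

X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
y = [1.0, 2.0, 3.0]
yhat = MyImputer.fit_and_predict(AnotherPackage.MeanEstimator(), X, y)
```

Both packages depend only on the small interface package, and dispatch on the model type picks the right method.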

User-provided functions are the ultimate generic solution. However, unless you guarantee that their types are type parameters of the struct, so that the whole code specializes on them, they may be slow if they are called many times and each call is expected to be fast.

Sorry, I haven’t understood. What do you mean by this?

I indeed ended up adding the following two fields to my UniversalImputer:

"The function provided by the estimator package(s) to fit the model. It should take as first argument the model itself, as second argument a matrix representing the features, and as third argument a vector representing the labels. This parameter is mandatory for non-BetaML estimators and can be a single value or a vector (one per dimension) in case of different packages used. [default: `BetaML.fit!`]"
fit_function::Union{Vector{Function},Function}     = BetaML.fit!
"The function provided by the estimator package to predict the labels. It should take as first argument the model itself and as second argument a matrix representing the features. This parameter is mandatory for non-BetaML estimators and can be a single value or a vector (one per dimension) in case of different packages used. [default: `BetaML.predict`]"
predict_function::Union{Vector{Function},Function} = BetaML.predict
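A simplified sketch of how such fields could be used, assuming a single function per field. SketchImputer, impute_column, MeanEstimator, train! and estimate are hypothetical names for illustration and are not BetaML's actual implementation:

```julia
# Hypothetical estimator from an unrelated package, with its own function names.
mutable struct MeanEstimator
    mean::Float64
end
train!(m::MeanEstimator, X, y) = (m.mean = sum(y) / length(y); m)
estimate(m::MeanEstimator, X) = fill(m.mean, size(X, 1))

# Simplified imputer: stores the estimator and the functions that drive it.
Base.@kwdef struct SketchImputer
    estimator::Any
    fit_function::Function
    predict_function::Function
end

# Impute column `d` of `X`, using the other columns as features.
# Assumes the other columns contain no missing values.
function impute_column(imp::SketchImputer, X::AbstractMatrix, d::Integer)
    known   = findall(!ismissing, X[:, d])
    unknown = findall(ismissing, X[:, d])
    others  = setdiff(1:size(X, 2), d)
    features = Float64.(X[:, others])
    # The imputer never names the estimator's package: it only calls the stored functions.
    imp.fit_function(imp.estimator, features[known, :], Float64.(X[known, d]))
    col = Vector{Float64}(undef, size(X, 1))
    col[known] = X[known, d]
    col[unknown] = imp.predict_function(imp.estimator, features[unknown, :])
    return col
end

X = [1.0 2.0; 3.0 missing; 5.0 6.0]
imp = SketchImputer(estimator = MeanEstimator(0.0),
                    fit_function = train!,
                    predict_function = estimate)
col2 = impute_column(imp, X, 2)
```

Here the missing entry is replaced by the mean of the known entries of that column, because that is all the toy estimator does.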

I meant having the struct UniversalImputer{F1, F2} ... with the fields declared as fit_function::F1 = ... and predict_function::F2 = ... But if you also accept a vector of functions (instead of a Tuple), the performance loss is inevitable, so it is better to keep it the way you already did.
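A small sketch of the difference, with hypothetical names (TypedImputer, my_fit!, my_predict). With type parameters each field has a concrete type the compiler can specialize on, while a vector of different functions only knows that its elements are some Function:

```julia
# Hypothetical placeholder functions.
my_fit!(m, X, y) = m
my_predict(m, X) = zeros(size(X, 1))

# The function types become part of the struct's type.
struct TypedImputer{F1,F2}
    fit_function::F1
    predict_function::F2
end

imp = TypedImputer(my_fit!, my_predict)

# Each field has a concrete (singleton) function type.
concrete_field = isconcretetype(fieldtype(typeof(imp), :fit_function))

# A tuple keeps the type of every element...
fs_tuple = (my_fit!, my_predict)
tuple_is_concrete = isconcretetype(typeof(fs_tuple))

# ...while a vector of different functions has the abstract element type Function.
fs_vector = [my_fit!, my_predict]
vector_eltype = eltype(fs_vector)
```

Calls through an abstractly typed container are resolved at run time, which matters mainly when the call itself is cheap and repeated many times.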

Ah, thanks… Yes, indeed, it is not the functions taking the model struct that do the actual computation, but a deeply nested inner function, so it isn’t much of an issue…
Thanks for clarifying…

You don’t want to constrain your functions to be subtypes of Function: they might be callable structs. See the very end of this paragraph in the performance tips.

Is there a way to accept any callable object? Again, it is not too important: otherwise I can leave the field’s type unspecified, as it is not in a performance-critical part, at least compared to the inner algorithms…

You can make the type of your callable object a subtype of Function anyway, so I do not believe this is restrictive.
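One way to accept any callable, sketched with hypothetical types (Scaler, Strict, Relaxed): leave the field unconstrained through a type parameter instead of annotating it with Function.

```julia
# Hypothetical callable struct that does not subtype Function.
struct Scaler
    factor::Float64
end
(s::Scaler)(x) = s.factor * x

# Field constrained to Function: callable structs like Scaler are rejected.
struct Strict
    f::Function
end

# Unconstrained type parameter: anything can be stored, and the field type stays concrete.
struct Relaxed{F}
    f::F
end

scaler_is_function = Scaler(2.0) isa Function

relaxed = Relaxed(Scaler(2.0))
result = relaxed.f(3.0)

# Constructing Strict with a Scaler fails, since Scaler cannot be converted to Function.
strict_error = try
    Strict(Scaler(2.0))
catch e
    e
end
```

The trade-off is that Relaxed no longer checks at construction time that the stored object is callable; a wrong argument only fails when it is called.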

Except if the object is not yours to begin with

You can always wrap it with a zero cost abstraction.

julia> struct FunctionWrapper{T} <: Function; callable :: T; end

julia> function (fw::FunctionWrapper{T})(x...; kw...) where T; fw.callable(x...; kw...); end

(Note: the slurp and splat are not necessarily zero-cost, but you could define specific call methods for your specific type which would be. Here I just created a generic method that will work for any set of signatures.)
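A sketch of the fixed-arity variant mentioned in the note, assuming the imputer only ever calls fit-like functions with three arguments and predict-like functions with two. ThirdPartyPredictor is a hypothetical callable that you do not own:

```julia
# Hypothetical third-party callable that does not subtype Function.
struct ThirdPartyPredictor
    offset::Float64
end
(p::ThirdPartyPredictor)(model, X) = fill(model + p.offset, size(X, 1))

# Wrapper that is a Function and forwards calls to the wrapped object.
struct FunctionWrapper{T} <: Function
    callable::T
end

# Fixed-arity methods instead of slurping and splatting arbitrary arguments.
(fw::FunctionWrapper)(model, X, y) = fw.callable(model, X, y)  # fit-like call
(fw::FunctionWrapper)(model, X)    = fw.callable(model, X)     # predict-like call

wrapped = FunctionWrapper(ThirdPartyPredictor(1.0))
wrapped_is_function = wrapped isa Function
pred = wrapped(2.0, zeros(3, 2))
```

The wrapped object can then be passed wherever a field or argument is annotated with Function.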
