After training an MLJ model on data that satisfies the Tables.jl interface, I'd like to deploy the model in an environment where `predict` is called on objects that don't necessarily implement that interface: for instance, a single `DataFrameRow`, a plain struct, or a hierarchical struct with a custom `getproperty` method:
```julia
using DataFrames
using MLJ
import MLJDecisionTreeInterface.DecisionTreeClassifier as Tree

# Training data
iris = DataFrame(load_iris());
y, X = unpack(iris, ==(:target); rng=123)

# Model training
tree = Tree()
mach = machine(tree, X, y)
train, test = partition(eachindex(y), 0.7)
fit!(mach, rows=train)

# Predicting on a single observation (a DataFrameRow) gives an error.
predict(mach, X[test[1], :])

# Predicting on a flat struct gives an error.
struct PlantFlat
    sepal_length::Float64
    sepal_width::Float64
    petal_length::Float64
    petal_width::Float64
end
plant = PlantFlat(1.0, 1.0, 1.0, 1.0)
predict(mach, plant)

# Predicting on a hierarchical struct with a custom getproperty.
struct PlantStructure
    length::Float64
    width::Float64
end
struct PlantHierarchical
    sepal::PlantStructure
    petal::PlantStructure
end
# Map flat names like :sepal_length onto the nested fields
# (chopprefix requires Julia 1.8 or later).
function Base.getproperty(plant::PlantHierarchical, name::Symbol)
    name_str = String(name)
    if startswith(name_str, "sepal_")
        sepal = getfield(plant, :sepal)
        field = Symbol(chopprefix(name_str, "sepal_"))
        return getproperty(sepal, field)
    elseif startswith(name_str, "petal_")
        petal = getfield(plant, :petal)
        field = Symbol(chopprefix(name_str, "petal_"))
        return getproperty(petal, field)
    else
        throw(ArgumentError("Invalid property name: $name"))
    end
end
plant = PlantHierarchical(PlantStructure(1.0, 1.0), PlantStructure(1.0, 1.0))
predict(mach, plant)
```
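One workaround that avoids implementing Tables.jl for every type is to convert the incoming objects into a NamedTuple of vectors, which Tables.jl recognises as a column table, right before calling `predict`. The sketch below assumes the feature names are known up front; `to_columntable` and `FEATURES` are hypothetical names introduced here, not part of MLJ or Tables.jl.

```julia
# Same flat struct as in the question.
struct PlantFlat
    sepal_length::Float64
    sepal_width::Float64
    petal_length::Float64
    petal_width::Float64
end

# Hypothetical helper (not an MLJ/Tables.jl API): gather the named properties
# of each object into a NamedTuple of vectors, i.e. a column table.
function to_columntable(objs::AbstractVector, names::Tuple{Vararg{Symbol}})
    cols = Tuple([getproperty(o, n) for o in objs] for n in names)
    return NamedTuple{names}(cols)
end

# Single-observation convenience method: produces a one-row table.
to_columntable(obj, names::Tuple{Vararg{Symbol}}) = to_columntable([obj], names)

# Assumed to match the training columns, in the same order.
const FEATURES = (:sepal_length, :sepal_width, :petal_length, :petal_width)
```

With the fitted machine from the question, `predict(mach, to_columntable(plant, FEATURES))` should then receive a valid table. Because the helper only uses `getproperty`, the same call should also work for the hierarchical struct with its custom `getproperty`. For the `DataFrameRow` case, indexing with a range, `X[test[1:1], :]`, keeps a one-row `DataFrame` instead of a row.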
I thought defining `getproperty` might be enough, but will I need to implement the full Tables.jl interface to predict on a hierarchical object? Thanks for the help!
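A likely reason `getproperty` alone is not enough: generic code that discovers column names does so through `propertynames`, whose default returns the struct's field names, here `(:sepal, :petal)`, and the custom `getproperty` throws for exactly those names. As far as I know, Tables.jl's default row accessors (`Tables.columnnames`, `Tables.getcolumn`) fall back to `propertynames`/`getproperty`, but I have not verified that MLJ accepts a bare vector of such structs, so the sketch below still converts explicitly. It adds a `Base.propertynames` method and a hypothetical `rows_to_columntable` helper (not a library function).

```julia
struct PlantStructure
    length::Float64
    width::Float64
end

struct PlantHierarchical
    sepal::PlantStructure
    petal::PlantStructure
end

# Flat accessors, same idea as in the question written more compactly.
# getfield is used internally to avoid recursing into this method.
function Base.getproperty(p::PlantHierarchical, name::Symbol)
    s = String(name)
    for part in (:sepal, :petal)
        prefix = string(part, "_")
        if startswith(s, prefix)
            return getfield(getfield(p, part), Symbol(chopprefix(s, prefix)))
        end
    end
    throw(ArgumentError("Invalid property name: $name"))
end

# Advertise the flat names so that code relying on propertynames sees them.
Base.propertynames(::PlantHierarchical, private::Bool=false) =
    (:sepal_length, :sepal_width, :petal_length, :petal_width)

# Hypothetical helper: build a NamedTuple-of-vectors column table, taking
# the column names from propertynames of the first object.
function rows_to_columntable(rows::AbstractVector)
    colnames = propertynames(first(rows))
    cols = Tuple([getproperty(r, n) for r in rows] for n in colnames)
    return NamedTuple{colnames}(cols)
end
```

Under these assumptions, `predict(mach, rows_to_columntable([plant]))` would hand MLJ a NamedTuple of vectors. Keeping the names returned by `propertynames` in the same order as the training columns avoids relying on name-based column matching.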