Using a trained MLJ model for prediction on non-Table objects

After training an MLJ model on data that meets the Tables.jl interface requirements, I'd like to deploy it in an environment where predict is called on objects that don't necessarily implement the Tables.jl interface: for instance, a single DataFrameRow, a struct, or a hierarchical struct with a custom getproperty method:

using DataFrames
using MLJ
import MLJDecisionTreeInterface.DecisionTreeClassifier as Tree

# Training Data
iris = DataFrame(load_iris());
y, X = unpack(iris, ==(:target); rng=123)

# Model Training
tree = Tree()
mach = machine(tree, X, y)
train, test = partition(eachindex(y), 0.7)
fit!(mach, rows=train)

# Predict on a single observation (a DataFrameRow): gives an error.
predict(mach, X[test[1],:])

# Predict on a flat struct: gives an error.
struct PlantFlat
    sepal_length::Float64
    sepal_width::Float64
    petal_length::Float64
    petal_width::Float64
end
plant = PlantFlat(1.0,1.0,1.0,1.0)
predict(mach, plant)

# Predict on a hierarchical struct: also gives an error.
struct PlantStructure
    length::Float64
    width::Float64
end
struct PlantHierarchical
    sepal::PlantStructure
    petal::PlantStructure
end
function Base.getproperty(plant::PlantHierarchical, name::Symbol)
    name_str = String(name)
    if startswith(name_str, "sepal_")
        sepal = getfield(plant, :sepal)
        field = Symbol(chopprefix(name_str, "sepal_"))
        return getproperty(sepal, field)
    elseif startswith(name_str, "petal_")
        petal = getfield(plant, :petal)
        field = Symbol(chopprefix(name_str, "petal_"))
        return getproperty(petal, field)
    else
        throw(ArgumentError("Invalid property name: $name"))
    end
end
plant = PlantHierarchical(PlantStructure(1.0,1.0),PlantStructure(1.0,1.0))
predict(mach, plant)

I thought the getproperty method might be enough. Will I need to implement the full Tables.jl interface to use a hierarchical object? Thanks for the help!

MLJ doesn’t really provide tools for you to do this kind of thing.

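To predict on a single row of a DataFrame, index with a vector of row indices, so that you get a one-row DataFrame (which is a table) rather than a DataFrameRow: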

predict(mach, X[[test[1]],:])
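The difference is only in how the DataFrame is indexed: an integer row index returns a `DataFrameRow`, while a vector of indices (or a unit range) returns a one-row `DataFrame`. A small sketch, using a hypothetical two-row stand-in for the iris features rather than the actual training data:

```julia
using DataFrames

# Hypothetical two-row stand-in for the iris feature table
X = DataFrame(sepal_length = [5.1, 4.9], sepal_width = [3.5, 3.0],
              petal_length = [1.4, 1.4], petal_width = [0.2, 0.2])

i = 1
row = X[i, :]             # a DataFrameRow: one row, which predict rejects in the question
one_row = X[[i], :]       # a one-row DataFrame: a table that predict accepts
also_one_row = X[i:i, :]  # a unit range gives the same one-row DataFrame
```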

For the other cases, you could either implement the Tables.jl interface for your types, or insert a function that converts your custom format to a table before calling predict; for example, a named tuple of vectors would do.
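A minimal sketch of the conversion approach, assuming the training columns are named and ordered `sepal_length, sepal_width, petal_length, petal_width`. The helpers `obs` and `as_table` are hypothetical names, not part of MLJ, and the structs are repeated from the question so the block runs on its own. The column order is kept the same as in training as a precaution, in case the model interface converts the table to a matrix by position rather than by name.

```julia
# Structs from the question, repeated so this sketch is self-contained
struct PlantFlat
    sepal_length::Float64
    sepal_width::Float64
    petal_length::Float64
    petal_width::Float64
end

struct PlantStructure
    length::Float64
    width::Float64
end

struct PlantHierarchical
    sepal::PlantStructure
    petal::PlantStructure
end

# Hypothetical helper: one observation as a named tuple of scalars,
# with the same field names and order as the training columns
obs(p::PlantFlat) = (sepal_length = p.sepal_length, sepal_width = p.sepal_width,
                     petal_length = p.petal_length, petal_width = p.petal_width)

# getfield is used so this works whether or not the custom getproperty is defined
function obs(p::PlantHierarchical)
    sepal = getfield(p, :sepal)
    petal = getfield(p, :petal)
    return (sepal_length = sepal.length, sepal_width = sepal.width,
            petal_length = petal.length, petal_width = petal.width)
end

# Hypothetical helper: build a named tuple of vectors (a Tables.jl column table)
# from a vector of observations
function as_table(plants::AbstractVector)
    rows = map(obs, plants)
    names = keys(first(rows))
    return NamedTuple{names}(Tuple([r[k] for r in rows] for k in names))
end

# A single observation becomes a one-row table
as_table(p) = as_table([p])
```

With the machine trained in the question, something like `predict(mach, as_table(plant))` should then return a one-element vector of probabilistic predictions (`predict_mode` gives class labels instead), and `as_table(plants)` converts a whole batch of structs in one call. Using `getfield` in the `PlantHierarchical` method avoids relying on the custom `getproperty`, which only understands the flattened names.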

Perhaps you could say a little more about your use case, which seems unusual. Why is the form of your production data different from that of your training data?
