Fitting a multiple input Flux.jl model with learning networks in MLJ.jl

Hey there, I have been trying to use MLJ.jl to fit a Flux model that takes multiple inputs. While the Flux.jl documentation shows how to build such a model, I cannot get it to work within MLJ.

Here is an MWE that attempts to use learning networks:

```julia
using MLJ
using Flux, MLJFlux
using UnPack

# Defining custom Flux model
struct PowerModelNN{NN1, NN2}
    nn1::NN1
    nn2::NN2
end
Flux.@functor PowerModelNN

# Forward pass: `a` is raised to the power nn2(pred) and scaled by nn1(pred)
function (m::PowerModelNN)(xs::Tuple)
    a, pred = xs
    return m.nn1(pred) .* a .^ m.nn2(pred)
end
(m::PowerModelNN)(xs...) = m(xs)

# Defining custom composite model
mutable struct MultipleInputCompositeModel <: MLJ.DeterministicNetworkComposite
    fsa_model      # selects the `a` column
    fspred_model   # selects the predictor columns (everything except `a`)
    std_model      # standardizes the predictors
    sar_model      # the Flux-based regressor
end

import MLJBase
function MLJBase.prefit(m::MultipleInputCompositeModel, verbosity::Int, X, y)
    @unpack fsa_model, fspred_model, std_model, sar_model = m

    Xs = source(X)
    ys = source(y)
    as = MLJ.transform(machine(fsa_model, Xs), Xs)
    preds = MLJ.transform(machine(fspred_model, Xs), Xs)
    pred_std = MLJ.transform(machine(std_model, preds), preds)

    # Here is the bit which is failing
    sar_machine = machine(sar_model, (as, pred_std), ys)
    ŷ = predict(sar_machine, (as, pred_std))
    return (; predict = ŷ)
end

# Building multiple input model and machine
NN = MLJ.@load NeuralNetworkRegressor pkg=MLJFlux

# preds has two columns (x1, x2) once `a` is dropped
nn1 = Chain(Dense(2, 10), Dense(10, 1, relu, bias=false))
nn2 = Chain(Dense(2, 10), Dense(10, 1, relu, bias=false))
builder = MLJFlux.@builder PowerModelNN(nn1, nn2)

# One model per field of MultipleInputCompositeModel, in field order
mdl = MultipleInputCompositeModel(FeatureSelector(features=[:a]),
                                  FeatureSelector(features=[:a], ignore=true),
                                  Standardizer(),
                                  NN(builder=builder, loss=Flux.poisson_loss))

# Generating synthetic data
using DataFrames
X = DataFrame(:a => rand(100), :x1 => randn(100), :x2 => randn(100))
y = exp.(2 * X.x1 .+ X.x2) .* X.a .^ exp.(X.x2)

# Fitting the machine
cm = machine(mdl, X, y)
```

I get the error `Mixing concrete data with Node training arguments is not allowed.` I believe this comes from the fact that a tuple of `Node`s cannot be used as an argument to a model or machine. Does anyone have an idea how to make this work?
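To see why the tuple trips this check, here is a toy reproduction in plain Julia. It assumes, as the wording of the message suggests, that `machine` requires its training arguments to be either all nodes or all concrete data. `ToyNode` and `check_training_args` are hypothetical stand-ins, not MLJBase's actual types or functions.

```julia
# Toy stand-in for a learning-network node (hypothetical, not MLJBase.Node).
struct ToyNode
    name::Symbol
end

# Hypothetical version of the check behind the error: training arguments must be
# either all nodes or all concrete data, never a mixture.
function check_training_args(args...)
    isnode = map(a -> a isa ToyNode, args)
    if any(isnode) && !all(isnode)
        throw(ArgumentError("Mixing concrete data with Node training arguments is not allowed."))
    end
    return true
end

as, pred_std, ys = ToyNode(:as), ToyNode(:pred_std), ToyNode(:ys)

check_training_args(as, ys)                  # fine: every argument is a node
try
    # A Tuple of nodes is not itself a node, so it counts as concrete data here
    check_training_args((as, pred_std), ys)
catch err
    println(err.msg)
end
```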


I haven’t tried to reproduce this, but it sounds like you just want to join (horizontally concatenate) the two tables represented by the nodes `as` and `pred_std`, yes?

If we assume the tables are DataFrames with non-intersecting column names, then you can try replacing `(as, pred_std)` with `hcat(as, pred_std)`, which hopefully works because `hcat` is overloaded to work on nodes out of the box (the long-hand would be `node(hcat, as, pred_std)`).
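As a rough mental model of what the long-hand `node(hcat, ...)` means, here is a toy in plain Julia: a node stores a function together with its parent nodes and only computes when it is called. `Source`, `LazyNode`, the local `node` function and the `hcat` overload are all hypothetical stand-ins, not MLJBase's implementation.

```julia
abstract type AbstractToyNode end

# A source simply returns the data it wraps.
struct Source <: AbstractToyNode
    data::Any
end
(s::Source)() = s.data

# A lazy node applies `f` to the (recursively evaluated) outputs of its parents.
struct LazyNode <: AbstractToyNode
    f::Function
    parents::Vector{AbstractToyNode}
end
(n::LazyNode)() = n.f(map(p -> p(), n.parents)...)

node(f, parents::AbstractToyNode...) = LazyNode(f, collect(AbstractToyNode, parents))

# Overloading hcat for nodes makes `hcat(a, b)` shorthand for `node(hcat, a, b)`.
Base.hcat(a::AbstractToyNode, b::AbstractToyNode) = node(hcat, a, b)

as = Source([1.0, 2.0])
pred_std = Source([3.0 4.0; 5.0 6.0])
joined = hcat(as, pred_std)   # nothing is computed yet
joined()                      # 2×3 matrix: [1.0 3.0 4.0; 2.0 5.0 6.0]
```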

I don’t know the latest recommendation for general tables, but this issue may point you to a solution: Add method to horizontally concatenate two (or more) tables of possibly different type · Issue #30 · JuliaData/TableOperations.jl · GitHub
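For the simple case of column tables stored as `NamedTuple`s of vectors (an assumption; this is not necessarily the approach discussed in the linked issue), horizontal concatenation with non-intersecting column names can be sketched with `merge`. `hcat_columntables` is a hypothetical helper.

```julia
# Hypothetical helper: horizontally join two column tables (NamedTuples of vectors).
function hcat_columntables(t1::NamedTuple, t2::NamedTuple)
    # Require disjoint column names, as assumed above.
    common = intersect(keys(t1), keys(t2))
    isempty(common) || throw(ArgumentError("duplicate column names: $(common)"))
    # Require equal row counts.
    n1 = length(first(t1))
    n2 = length(first(t2))
    n1 == n2 || throw(DimensionMismatch("tables have $n1 and $n2 rows"))
    return merge(t1, t2)
end

as = (a = [0.1, 0.5, 0.9],)
pred_std = (x1 = [-1.0, 0.0, 1.0], x2 = [0.3, -0.2, 0.7])
joined = hcat_columntables(as, pred_std)   # columns a, x1, x2
```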

Thanks for the swift reply!

Unfortunately, your solution does not work, because the Flux model `NN` has to take a tuple as its argument, not a single table. This is required because the inputs `as` and `pred_std` are processed differently inside the Flux model.

Based on your suggestion, I tried `node(tuple, as, pred_std)`, but unfortunately this does not work either.
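For reference, here is the forward pass the custom model needs, written in plain Julia without Flux. The true functions used to generate the synthetic data (ignoring the standardization step) stand in for `nn1` and `nn2`; `PowerModelSketch`, `f1` and `f2` are hypothetical names. It shows why the two inputs play different roles: `nn1` and `nn2` only see `pred`, while `a` enters as the base of a learned exponent.

```julia
# Plain-Julia stand-in for PowerModelNN: two sub-models applied to `pred`,
# combined with `a` as y = nn1(pred) .* a .^ nn2(pred).
struct PowerModelSketch{F1,F2}
    nn1::F1
    nn2::F2
end

function (m::PowerModelSketch)(xs::Tuple)
    a, pred = xs                      # a: 1×n, pred: p×n (features × observations)
    return m.nn1(pred) .* a .^ m.nn2(pred)
end

# "Oracle" sub-models matching the synthetic target in the question:
# y = exp(2*x1 + x2) * a^exp(x2), with pred rows (x1, x2).
f1(pred) = exp.(2 .* pred[1:1, :] .+ pred[2:2, :])
f2(pred) = exp.(pred[2:2, :])

model = PowerModelSketch(f1, f2)
a = [0.2 0.5 0.8]                     # 1×3
pred = [0.1 -0.4 1.0; 0.3 0.0 -0.5]   # 2×3: first row x1, second row x2
yhat = model((a, pred))               # 1×3 row of predictions
```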

Sorry, I indeed misunderstood your problem.

Your current approach will not work because it violates an assumption about how `MLJFlux.NeuralNetworkRegressor` works: the Flux model `m` created by the builder can only be called on matrices (and vectors), not on tuples of matrices.

In more detail, the `X` in a call such as `machine(NeuralNetworkRegressor(...), X, y)` must be a matrix or a table with, say, `p` columns. This `X` is converted to a vector of `p × b` matrices, where `b` is the batch size, and during training `m` is called on each of these matrices.
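A simplified sketch of that conversion, assuming an `n × p` matrix input: transpose it to `p × n` (Flux's features-by-observations layout) and cut the columns into batches of at most `b` observations. `to_batches` is a hypothetical helper; MLJFlux's actual conversion code (which also accepts tables) differs in its details.

```julia
# Hypothetical helper: split an n×p data matrix into p×b batches
# (the last batch may be smaller).
function to_batches(X::AbstractMatrix, b::Integer)
    Xt = permutedims(X)                        # p×n: one column per observation
    n = size(Xt, 2)
    return [Xt[:, i:min(i + b - 1, n)] for i in 1:b:n]
end

X = reshape(collect(1.0:20.0), 10, 2)          # 10 observations, p = 2 features
batches = to_batches(X, 4)                     # sizes: 2×4, 2×4, 2×2
```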

In case you are curious, here are some relevant sections of the code base:

Ok, thanks a lot for pointing me to these pieces of code. I’ll try overloading `MLJFlux.collate` (and related methods) for a new `MultiInputNeuralNetworkRegressor`. I’ll post the solution once I find it, and maybe open a PR if relevant.
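One possible shape for such a multi-input batching step, sketched in plain Julia under the assumption that each batch should become a tuple `(a_batch, pred_batch)` that the custom Flux model can destructure with its `xs::Tuple` method. The function name `to_tuple_batches`, the `a_cols` argument and the column-split convention are all hypothetical; this is not MLJFlux's API.

```julia
# Hypothetical multi-input batching: split columns into the `a` input and the rest,
# then build (a_batch, pred_batch) tuples in features × observations layout.
function to_tuple_batches(X::AbstractMatrix, a_cols, b::Integer)
    pred_cols = setdiff(axes(X, 2), a_cols)
    A = permutedims(X[:, a_cols])       # length(a_cols)×n
    P = permutedims(X[:, pred_cols])    # length(pred_cols)×n
    n = size(X, 1)
    ranges = (i:min(i + b - 1, n) for i in 1:b:n)
    return [(A[:, r], P[:, r]) for r in ranges]
end

X = hcat(rand(10), randn(10), randn(10))   # columns: a, x1, x2
batches = to_tuple_batches(X, [1], 4)
a_batch, pred_batch = first(batches)       # 1×4 and 2×4 matrices
```

Each element of `batches` could then be passed straight to a model like `PowerModelNN`, which splits the tuple internally.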