Hey there, I have been trying to use MLJ.jl to fit a Flux model that takes multiple inputs. While the Flux.jl documentation demonstrates how to build such a model, I unfortunately cannot make it work within MLJ.
Here is an MWE that attempts to use learning networks.
using MLJ
using Flux, MLJFlux
using UnPack
# Defining custom Flux model
struct PowerModelNN{NN1, NN2}
    nn1::NN1
    nn2::NN2
end
Flux.@functor PowerModelNN
function (m::PowerModelNN)(xs::Tuple)
    a, pred = xs
    return m.nn1(pred) .* a .^ m.nn2(pred)
end
(m::PowerModelNN)(xs...) = m(xs)
# Defining custom composite model
mutable struct MultipleInputCompositeModel <: MLJ.DeterministicNetworkComposite
    fsa_model
    fspred_model
    std_model
    sar_model
end
import MLJBase
function MLJBase.prefit(m::MultipleInputCompositeModel, verbosity::Int, X, y)
    @unpack fsa_model, fspred_model, std_model, sar_model = m
    Xs = source(X)
    ys = source(y)
    as = MLJ.transform(machine(fsa_model, Xs), Xs)
    preds = MLJ.transform(machine(fspred_model, Xs), Xs)
    pred_std = MLJ.transform(machine(std_model, preds), preds)
    # Here is the bit which is failing
    sar_machine = machine(sar_model, (as, pred_std), ys)
    ŷ = predict(sar_machine, (as, pred_std))
    return (; predict = ŷ)
end
# Building multiple input model and machine
NN = MLJ.@load NeuralNetworkRegressor pkg=MLJFlux
nn1 = Chain(Dense(3, 10), Dense(10, 1, relu, bias=false))
nn2 = Chain(Dense(3, 10), Dense(10, 1, relu, bias=false))
builder = MLJFlux.@builder PowerModelNN(nn1, nn2)
mdl = MultipleInputCompositeModel(FeatureSelector(features=[:a]),
                                  FeatureSelector(features=[:a], ignore=true),
                                  Standardizer(),
                                  NN(builder=builder, loss=Flux.poisson_loss))
# Generating synthetic data
using DataFrames
X = DataFrame(:a => rand(100), :x1 => randn(100), :x2 => randn(100))
y = exp.(2 * X.x1 .+ X.x2) .* X.a .^ exp.(X.x2)
# Fitting the machine
cm = machine(mdl, X, y)
fit!(cm)
I get the error "Mixing concrete data with `Node` training arguments is not allowed." I believe this comes from the fact that one cannot pass a tuple of `Node`s as an argument to a model or machine. Does anyone have an idea for a solution to make this work?
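To illustrate what I suspect is going on, here is a minimal sketch, based on my reading of the MLJBase learning-network API rather than anything I have confirmed for this use case. A plain tuple of nodes is not itself a node, whereas `node(f, args...)` wraps a function of several nodes into a single node. The names `a`, `p` and `pair` are made up for illustration, and I do not know whether MLJFlux would accept the resulting tuple as model input.

```julia
using MLJBase

# Two source nodes standing in for `as` and `pred_std` (illustrative data only)
a = source([1.0, 2.0])
p = source([3.0, 4.0])

# A plain tuple of nodes is ordinary data, not a node
is_node = (a, p) isa MLJBase.AbstractNode

# `node` builds a single node whose value is the tuple of the upstream values
pair = node((u, v) -> (u, v), a, p)

# Calling a node with no arguments evaluates it on the data in the sources
value = pair()
```

If this is right, something like `machine(sar_model, pair, ys)` would at least get past the mixed-arguments check, but whether the model's input checks then accept a tuple is a separate question.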
Cheers!