# Simultaneous training of multiple neural networks with Lux

Suppose I’m trying to train a neural ODE composed of two distinct neural networks, for instance because I want to preserve information about which model states influence which other states.

This worked easily with the old FastChain/DiffEqFlux API by just concatenating the networks’ parameters into a single vector. However, I’m not sure how to do this using Lux.

Here’s what I have so far:

```julia
using Lux, DiffEqFlux, Zygote
using Optimization, OptimizationOptimJL, OptimizationFlux, OptimizationPolyalgorithms
using DifferentialEquations
using LinearAlgebra
using Plots
using Random; rng = Random.default_rng()

# Ground-truth dynamics, written in place: du1 = u2, du2 = -u1 (a simple harmonic oscillator).
function trueode(du, u, p, t)
du[1] = u[2]
du[2] = -u[1]
nothing
end

u0 = Float32[0; 1]
# NOTE(review): `tspan` is a Vector here; the later docs example uses a Tuple `(0f0, 10f0)`.
tspan = Float32[0.0, 10]
# The true system takes no parameters.
p_ = SciMLBase.NullParameters()

prob = ODEProblem(trueode, u0, tspan, p_)
sol = solve(prob, Tsit5(), saveat=0.1)
# Training targets; `data[:,1]` is used below as the initial state, i.e. columns are time points.
data = Array(sol)
tsteps = sol.t

# Container layer holding two independent Lux networks. Listing the field names in
# `AbstractExplicitContainerLayer` lets `Lux.setup` build parameters/states that are
# accessed below as `p.network1` / `p.network2`.
struct NDE2Network{du1, du2} <:
Lux.AbstractExplicitContainerLayer{(:network1, :network2)}
network1::du1
network2::du2
end

# Each network maps a 1-element input to a 1-element output through one tanh hidden
# layer of width `hidden_dims`.
input_size = output_size = 1
function NDE2Network(hidden_dims)
return NDE2Network(
Lux.Chain(
Lux.Dense(input_size=>hidden_dims, tanh), Lux.Dense(hidden_dims=>output_size)),
Lux.Chain(
Lux.Dense(input_size=>hidden_dims, tanh), Lux.Dense(hidden_dims=>output_size)),
)
end

# In-place ODE right-hand side: du1 depends only on u2 and du2 only on u1, which
# encodes which states influence which.
# NOTE(review): `st` is not an argument; this reads the global `st` created by
# `Lux.setup` below. `[1][1]` takes the network output (first element of the
# returned tuple) and then its single entry.
function (NN::NDE2Network)(du, u, p, t)
du[1] = NN.network1([u[2]], p.network1, st.network1)[1][1]
du[2] = NN.network2([u[1]], p.network2, st.network2)[1][1]
nothing
end

network = NDE2Network(10)
p, st = Lux.setup(rng, network)
u0 = data[:,1]
# `network` itself is used as the in-place ODE function (du, u, p, t).
prob_nn = ODEProblem(network, u0, tspan, Lux.ComponentArray(p))

# NOTE(review): `predict` never uses its argument `p`. It always solves `prob_nn` with
# the parameters stored in the problem, so the loss does not depend on `p` and its
# gradient with respect to `p` is `nothing`. Pass the parameters to `solve`
# (e.g. `p = p`) or `remake` the problem inside `predict` so training can update them.
function predict(p)
Array(solve(prob_nn, Tsit5(), saveat=tsteps))
end

# Sum of squared errors between the prediction and the training data.
function loss(p)
pred = predict(p)
sum(abs2, pred .- data)
end

# The objective closure receives (parameters, extra argument); only the first is used.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, u) -> loss(p), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res = Optimization.solve(optprob, PolyOpt(), maxiters=500)
``````

The prediction and loss functions work well here, but when I try to run the last line (either with PolyOpt or with ADAM), Julia generates a truly gargantuan wall of output (several hundred lines) and crashes. The first few lines of output are

```llvm
Function Attrs: uwtable willreturn mustprogress
define internal fastcc void @preprocess_julia_NDE2Network_11750([2 x [1 x [2 x { i64, i64 }]]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(96) %0, {} addrspace(10)* nonnull align 16 dereferenceable(40) %1, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2, { {} addrspace(10)* } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(8) %3) unnamed_addr #9 !dbg !153 {
top:
%4 = alloca [1 x [2 x i64]], align 8
%5 = alloca [1 x [2 x i64]], align 8
%6 = call {}*** @julia.get_pgcstack() #8
%7 = bitcast {} addrspace(10)* %2 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !154
%8 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %7 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !154
%9 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %8, i64 0, i32 1, !dbg !154
%10 = load i64, i64 addrspace(11)* %9, align 8, !dbg !154, !tbaa !22, !range !25
%11 = icmp ugt i64 %10, 1, !dbg !154
br i1 %11, label %idxend, label %oob, !dbg !154

L27:                                              ; preds = %idxend
%12 = addrspacecast [1 x [2 x i64]]* %4 to [1 x [2 x i64]] addrspace(11)*, !dbg !156
%13 = call fastcc nonnull {} addrspace(10)* @julia_throw_boundserror_11754({} addrspace(10)* nonnull align 16 dereferenceable(40) %56, [1 x [2 x i64]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %12) #10, !dbg !156
unreachable, !dbg !156
``````

Where am I going wrong here?

You want to build a single ComponentArray that combines the two networks' parameters, like:

```julia
ComponentArray(layer_1 = Lux.ComponentArray(p1), layer_2 = Lux.ComponentArray(p2))
``````

We will document this once more of the docs are built out.

@avikpal anything else to add?

1 Like

Still getting an error, though at least the program doesn’t crash. Here is my implementation (lines generating training data are same as previously):

```julia
input_size = output_size = 1
hidden_dims = 5
# Two independent 1-input / 1-output networks with one tanh hidden layer each.
network1 = Lux.Chain(
Lux.Dense(input_size=>hidden_dims, tanh), Lux.Dense(hidden_dims=>output_size))
network2 = Lux.Chain(
Lux.Dense(input_size=>hidden_dims, tanh), Lux.Dense(hidden_dims=>output_size))

# In-place ODE right-hand side: du1 is a function of u2 only, du2 of u1 only.
# NOTE(review): reads the globals `network1`, `network2`, `st1`, `st2`; the parameters
# come from the `layer_1` / `layer_2` parts of the combined ComponentArray `p`.
function nde(du, u, p, t)
du[1] = network1([u[2]], p.layer_1, st1)[1][1]
du[2] = network2([u[1]], p.layer_2, st2)[1][1]
nothing
end

p1, st1 = Lux.setup(rng, network1)
p2, st2 = Lux.setup(rng, network2)
# Combine both parameter sets into one ComponentArray so the optimizer sees a single vector.
p_init = ComponentArray(layer_1 = Lux.ComponentArray(p1), layer_2 = Lux.ComponentArray(p2))
u0 = data[:,1]
prob_nn = ODEProblem(nde, u0, tspan, p_init)

# NOTE(review): `p` is unused here. `solve` always uses the parameters stored in
# `prob_nn`, so the gradient of `loss` with respect to `p` is `nothing`. That is the
# source of the `Cannot convert an object of type Nothing` error. Pass `p = p` to
# `solve` (or `remake` the problem) to fix it.
function predict(p)
Array(solve(prob_nn, Tsit5(), saveat=tsteps))
end

# Sum of squared errors between the prediction and the training data.
function loss(p)
pred = predict(p)
sum(abs2, pred .- data)
end

# Shows the current loss; returning `false` means "do not stop".
function callback(p, l)
display(l)
false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, u) -> loss(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
res = Optimization.solve(optprob, ADAM(0.05), maxiters=500, callback=callback)
``````

The prediction, loss, and callback still work, but now the last line gives
``MethodError: Cannot `convert` an object of type Nothing to an object of type Float32``
Any idea where this error might come from, or how to correct it?

```julia
function predict(p)
return Array(solve(prob_nn, Tsit5(); saveat=tsteps))
end
```

Your `predict` function has no dependence on `p`, so it makes sense that the gradient with respect to `p` is `nothing`. (I am low-key so happy this errored out :P)

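See the "MNIST Classification using NeuralODE" tutorial in the Lux.jl docs. You need to either `remake` the problem inside `predict` or create it inside `predict`. Passing the parameters directly to `solve` also works, e.g. `solve(prob_nn, Tsit5(); p = p, saveat = tsteps)`.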

@bkuwahara what you do is convert the parameters to component arrays, and then concatenate them as follows:

```julia
p1 = Lux.ComponentArray(p1)
p2 = Lux.ComponentArray(p2)

# Start from an empty Float32 ComponentArray and add each part under its own name
# (`;p1` is Julia's keyword shorthand for `p1 = p1`).
# `scaling_factor` must already be defined, e.g. `scaling_factor = 1f0` as in the
# full example below.
p = Lux.ComponentArray{Float32}()
p = Lux.ComponentArray(p;p1)
p = Lux.ComponentArray(p;p2)
p = Lux.ComponentArray(p;scaling_factor)
``````

The resulting `p` can now be indexed as `p.p1`, `p.p2`, and `p.scaling_factor` to reach the different parts, and each piece can be passed to the Lux network it belongs to while the optimizer still sees a single parameter vector.

This is also in the docs now: Simultaneous Fitting of Multiple Neural Networks · DiffEqSensitivity.jl

```julia
using Lux, DiffEqFlux, Optimization, OptimizationOptimJL, DifferentialEquations, Random

rng = Random.default_rng()

"""
    fitz(du, u, p, t)

In-place right-hand side of the FitzHugh–Nagumo model, used to generate the
ground-truth data.

- `u = (v, w)`: the two state variables.
- `p = (a, b, τinv, l)`: `l` is an additive constant in `dv/dt`, and `τinv` scales
  the whole `dw/dt` term.
- `t` is unused (the system is autonomous).

Writes the derivatives into `du` and returns `nothing`. This is the convention for
in-place ODE functions; without the explicit return, the function would leak the
value assigned to `du[2]`.
"""
function fitz(du, u, p, t)
    v, w = u
    a, b, τinv, l = p
    du[1] = v - v^3 / 3 - w + l
    du[2] = τinv * (v + a - b * w)
    return nothing
end

# Ground-truth parameters (a, b, τinv, l), initial state and time span, all in Float32.
p_ = Float32[0.7, 0.8, 1 / 12.5, 0.5]
u0 = Float32[1, 1]
tspan = (0.0f0, 10.0f0)
prob = ODEProblem(fitz, u0, tspan, p_)
sol = solve(prob, Tsit5(); saveat = 0.5)

# Noise-free trajectory sampled every 0.5 time units.
X = Array(sol)
# Training targets: the ideal data perturbed by small Gaussian noise.
Xₙ = X + Float32(1e-3) * randn(eltype(X), size(X))

# NN_1 models the first derivative (dv/dt) from the state [v, w].
NN_1 = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 1))
p1,st1 = Lux.setup(rng, NN_1)

# NN_2 models the second derivative from [v, w, t] (state plus time). Its output is
# multiplied by the trainable `scaling_factor` in `dudt_`.
NN_2 = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 1))
p2, st2 = Lux.setup(rng, NN_2)
# Extra trainable scalar, stored alongside the network weights.
scaling_factor = 1f0

# Convert each network's parameters (as returned by `Lux.setup`) into a ComponentArray.
p1 = Lux.ComponentArray(p1)
p2 = Lux.ComponentArray(p2)

# Start from an empty Float32 ComponentArray and add each piece under its own name,
# so `p.p1`, `p.p2` and `p.scaling_factor` reach the individual parts while `p` is
# passed to the optimizer as a single object.
p = Lux.ComponentArray{Float32}()
p = Lux.ComponentArray(p;p1)
p = Lux.ComponentArray(p;p2)
p = Lux.ComponentArray(p;scaling_factor)

"""
    dudt_(u, p, t)

Out-of-place neural right-hand side: `NN_1` sees the state `[v, w]`, and `NN_2` sees
`[v, w, t]` with its output scaled by `p.scaling_factor`. Uses the global layer
states `st1` / `st2` and returns a new 2-element vector.
"""
function dudt_(u, p, t)
    v, w = u
    # Lux layers return `(output, state)`; the returned state is discarded.
    out1, _ = NN_1([v, w], p.p1, st1)
    out2, _ = NN_2([v, w, t], p.p2, st2)
    return [out1[1], p.scaling_factor * out2[1]]
end
prob_nn = ODEProblem(dudt_, u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(); saveat = sol.t)

"""
    predict(θ)

Solve `prob_nn` with the parameters `θ`, which override those stored in the problem
so that gradients flow back to `θ`. Returns the trajectory as an array saved at the
ground-truth times `sol.t`. Gradients use `InterpolatingAdjoint` with ReverseDiff
vector–Jacobian products; the `true` flag requests a compiled ReverseDiff tape.
"""
function predict(θ)
    adjoint_alg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))
    traj = solve(prob_nn, Vern7();
                 p = θ, saveat = sol.t,
                 abstol = 1e-6, reltol = 1e-6,
                 sensealg = adjoint_alg)
    return Array(traj)
end

"""
    loss(θ)

Sum of squared errors between the noisy data `Xₙ` and `predict(θ)`, with no
regularisation right now. Returns `(loss, pred)`; the extra `pred` value is received
by `callback` as its third argument.
"""
function loss(θ)
    pred = predict(θ)
    return sum(abs2, Xₙ .- pred), pred
end
loss(p)

# Loss history recorded by `callback`. This uses a typed container instead of `[]`
# (a `Vector{Any}`); the loss is Float32 because the data and parameters are Float32.
const losses = Float32[]

"""
    callback(θ, l, pred)

Record the loss `l` and print the latest value every 50 iterations. Returning
`false` lets the optimization continue.
"""
function callback(θ, l, pred)
    push!(losses, l)
    if length(losses) % 50 == 0
        println(losses[end])
    end
    return false
end
# Reverse-mode gradients of the objective via Zygote.
adtype = Optimization.AutoZygote()
# The objective closure receives (parameters, extra argument); the second is ignored.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)

# Stage 1: 500 ADAM iterations starting from the initial parameters `p`.
optprob = Optimization.OptimizationProblem(optf, p)
res1_uode = Optimization.solve(optprob, ADAM(0.01); callback = callback, maxiters = 500)

# Stage 2: continue from the ADAM result with BFGS.
optprob2 = Optimization.OptimizationProblem(optf, res1_uode.u)
res2_uode = Optimization.solve(optprob2, BFGS(); maxiters = 10000, callback = callback)
``````