Simultaneous training of multiple neural networks with Lux

Suppose I’m trying to train a neural ODE composed of two distinct neural networks, for instance because I want to preserve information about which model states influence which other states.

This worked easily with the old FastChain/DiffEqFlux API by just concatenating the networks’ parameters into a single vector. However, I’m not sure how to do this using Lux.

Here’s what I have so far:

using Lux, DiffEqFlux, Zygote
using Optimization, OptimizationOptimJL, OptimizationFlux, OptimizationPolyalgorithms
using DifferentialEquations
using LinearAlgebra
using Plots
using Random; rng = Random.default_rng()


"""
    trueode(du, u, p, t)

In-place right-hand side of the reference system `u₁' = u₂`, `u₂' = -u₁`.
The parameters `p` and the time `t` are not used. Always returns `nothing`.
"""
function trueode(du, u, p, t)
    # Entries are written in order, first du[1] and then du[2].
    du[1] = u[2]
    du[2] = -u[1]
    return nothing
end

# Reference trajectory of `trueode`, used as training data for the neural ODE.
# `tspan` is written as a `(t0, tf)` tuple, the form `ODEProblem` documents,
# so it no longer relies on converting a 2-element vector.
u0 = Float32[0; 1]
tspan = (0.0f0, 10.0f0)
p_ = SciMLBase.NullParameters()  # trueode reads no parameters

prob = ODEProblem(trueode, u0, tspan, p_)
sol = solve(prob, Tsit5(), saveat=0.1)
data = Array(sol)   # states in rows, saved time points in columns
tsteps = sol.t


"""
    NDE2Network{N1, N2}

Container layer that bundles two Lux sub-networks, `network1` and `network2`.

Declaring both field names in the container-layer supertype makes `Lux.setup`
return parameters and states keyed `network1` / `network2`. The ODE functor for
this type indexes exactly those keys.
"""
struct NDE2Network{N1, N2} <: Lux.AbstractExplicitContainerLayer{(:network1, :network2)}
    network1::N1
    network2::N2
end


input_size = output_size = 1

"""
    NDE2Network(hidden_dims; in_dims = input_size, out_dims = output_size)

Build an `NDE2Network` whose two sub-networks are independent MLPs
`in_dims => hidden_dims (tanh) => out_dims`.

The keyword defaults read the globals `input_size` / `output_size` defined just
above, so the existing call `NDE2Network(hidden_dims)` behaves exactly as before.
The sizes can now also be chosen per call instead of by editing globals.
"""
function NDE2Network(hidden_dims; in_dims = input_size, out_dims = output_size)
    return NDE2Network(
        _nde_mlp(in_dims, hidden_dims, out_dims),
        _nde_mlp(in_dims, hidden_dims, out_dims),
    )
end

# One sub-network: a tanh hidden layer followed by a linear read-out layer.
_nde_mlp(in_dims, hidden_dims, out_dims) =
    Lux.Chain(Lux.Dense(in_dims => hidden_dims, tanh), Lux.Dense(hidden_dims => out_dims))



"""
    (model::NDE2Network)(du, u, p, t)

In-place ODE right-hand side:
- `du[1]` is `network1` evaluated on `u[2]`.
- `du[2]` is `network2` evaluated on `u[1]`.

Each derivative therefore depends on a single state. `p` must expose `network1`
and `network2` parameter blocks.

NOTE(review): the Lux states come from the *global* `st` (created by
`Lux.setup` later in the script), not from an argument. Calling this before
`st` exists raises an UndefVarError.
"""
function (model::NDE2Network)(du, u, p, t)
    out1, _ = model.network1([u[2]], p.network1, st.network1)
    out2, _ = model.network2([u[1]], p.network2, st.network2)
    du[1] = out1[1]
    du[2] = out2[1]
    return nothing
end

# Build the two-network model with 10 hidden units per sub-network and initialise
# its parameters `p` and states `st`. The NDE2Network functor reads this global `st`.
network = NDE2Network(10)
p, st = Lux.setup(rng, network)
u0 = data[:,1]  # start the neural ODE from the first observed state
# Flatten the nested parameter NamedTuple into a ComponentArray for the ODE problem.
prob_nn = ODEProblem(network, u0, tspan, Lux.ComponentArray(p))


"""
    predict(p)

Solve `prob_nn` with parameters `p` and return the trajectory sampled at
`tsteps` as a plain array.

`p` must be forwarded to `solve`. Otherwise the parameters stored in `prob_nn`
are used, the result does not depend on `p`, and the gradient with respect to
`p` is `nothing`.
"""
function predict(p)
    return Array(solve(prob_nn, Tsit5(); p = p, saveat = tsteps))
end

"""
    loss(p)

Sum of squared differences between the trajectory predicted with parameters
`p` and the reference `data`.
"""
function loss(p)
    residual = predict(p) .- data
    return sum(abs2, residual)
end

# Minimise `loss` with Zygote reverse-mode AD. The objective ignores its second
# argument. The initial point is the ComponentArray-flattened Lux parameters.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, u) -> loss(p), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res = Optimization.solve(optprob, PolyOpt(), maxiters=500)

The prediction and loss functions work well here, but when I try to run the last line (either with PolyOpt or with ADAM), Julia generates a truly gargantuan wall of output (several hundred lines) and crashes. The first few lines of output are:

Function Attrs: uwtable willreturn mustprogress
define internal fastcc void @preprocess_julia_NDE2Network_11750([2 x [1 x [2 x { i64, i64 }]]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(96) %0, {} addrspace(10)* nonnull align 16 dereferenceable(40) %1, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2, { {} addrspace(10)* } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(8) %3) unnamed_addr #9 !dbg !153 {
top:
  %4 = alloca [1 x [2 x i64]], align 8
  %5 = alloca [1 x [2 x i64]], align 8
  %6 = call {}*** @julia.get_pgcstack() #8
  %7 = bitcast {} addrspace(10)* %2 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !154
  %8 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %7 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !154
  %9 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %8, i64 0, i32 1, !dbg !154
  %10 = load i64, i64 addrspace(11)* %9, align 8, !dbg !154, !tbaa !22, !range !25       
  %11 = icmp ugt i64 %10, 1, !dbg !154
  br i1 %11, label %idxend, label %oob, !dbg !154

L27:                                              ; preds = %idxend
  %12 = addrspacecast [1 x [2 x i64]]* %4 to [1 x [2 x i64]] addrspace(11)*, !dbg !156   
  %13 = call fastcc nonnull {} addrspace(10)* @julia_throw_boundserror_11754({} addrspace(10)* nonnull align 16 dereferenceable(40) %56, [1 x [2 x i64]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %12) #10, !dbg !156
  unreachable, !dbg !156

Where am I going wrong here?

You want to build a single ComponentArray that combines the parameters of the two networks, like:

# One parameter vector with named blocks `layer_1` and `layer_2`, one per network.
ComponentArray(layer_1 = Lux.ComponentArray(p1), layer_2 = Lux.ComponentArray(p2))

We will document this once more of the docs are built out.

@avikpal anything else to add?

1 Like

Still getting an error, though at least the program doesn’t crash. Here is my implementation (the lines generating the training data are the same as before):

# Two independent MLPs with identical architecture:
# input_size => hidden_dims (tanh) => output_size.
input_size = output_size = 1
hidden_dims = 5
network1, network2 = (
    Lux.Chain(Lux.Dense(input_size => hidden_dims, tanh),
              Lux.Dense(hidden_dims => output_size))
    for _ in 1:2
)


"""
    nde(du, u, p, t)

In-place ODE right-hand side built from two separate Lux networks:
- `du[1]` is `network1` evaluated on `u[2]`.
- `du[2]` is `network2` evaluated on `u[1]`.

`p` must expose `layer_1` / `layer_2` parameter blocks, as `p_init` does. The
Lux states are read from the globals `st1` / `st2`.
"""
function nde(du, u, p, t)
    out1, _ = network1([u[2]], p.layer_1, st1)
    out2, _ = network2([u[1]], p.layer_2, st2)
    du[1] = out1[1]
    du[2] = out2[1]
    return nothing
end


# Initialise each network separately. `nde` reads st1/st2 as globals.
p1, st1 = Lux.setup(rng, network1)
p2, st2 = Lux.setup(rng, network2)
# One parameter vector with named blocks p.layer_1 / p.layer_2, matching what `nde` indexes.
p_init = ComponentArray(layer_1 = Lux.ComponentArray(p1), layer_2 = Lux.ComponentArray(p2))
u0 = data[:,1]  # start from the first observed state
prob_nn = ODEProblem(nde, u0, tspan, p_init)


"""
    predict(p)

Solve `prob_nn` with parameters `p` and return the trajectory sampled at
`tsteps` as a plain array.

`p` must be forwarded to `solve`. Otherwise the parameters stored in `prob_nn`
(`p_init`) are used, the result does not depend on `p`, and the gradient is
`nothing`. That is what caused
"Cannot convert an object of type Nothing to an object of type Float32".
"""
function predict(p)
    return Array(solve(prob_nn, Tsit5(); p = p, saveat = tsteps))
end

"""
    loss(p)

Sum of squared differences between the trajectory predicted with parameters
`p` and the reference `data`.
"""
function loss(p)
    residual = predict(p) .- data
    return sum(abs2, residual)
end

# Optimization callback: display the current loss `l`. Returning `false` lets
# training continue. The parameters `p` are not used.
callback(p, l) = (display(l); false)

# Minimise `loss` with Zygote reverse-mode AD. The objective ignores its second argument.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, u) -> loss(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
# ADAM with learning rate 0.05 for at most 500 iterations, reporting through `callback`.
res = Optimization.solve(optprob, ADAM(0.05), maxiters=500, callback=callback)

The prediction, loss, and callback still work, but now the last line gives:
MethodError: Cannot convert an object of type Nothing to an object of type Float32
Any idea where this error might come from, or how to correct it?

function predict(p)
    # NOTE: `p` is never used here. `solve` runs with the parameters stored in
    # `prob_nn`, so the result does not depend on `p` and its gradient is `nothing`.
    return Array(solve(prob_nn, Tsit5(); saveat=tsteps))
end

Your function has no dependence on `p`, so it makes sense that the gradient with respect to `p` is `nothing`. (I am low-key so happy this errored out :P)

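See the "MNIST Classification using NeuralODE" tutorial in the Lux.jl docs. You need to either `remake` the problem with the new parameters inside `predict`, or create the problem inside `predict`. Passing the parameters directly to `solve` as a keyword (`solve(prob_nn, Tsit5(); p = p, saveat = tsteps)`) also works.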

@bkuwahara what you do is convert the parameters to ComponentArrays and then concatenate them as follows:

# Flatten each network's parameter NamedTuple into a ComponentArray.
p1 = Lux.ComponentArray(p1)
p2 = Lux.ComponentArray(p2)

# Start from an empty Float32 ComponentArray and add one named block at a time.
# The result is addressed as p.p1, p.p2 and p.scaling_factor.
# NOTE(review): `scaling_factor` is not defined in this snippet; it must exist beforehand.
p = Lux.ComponentArray{Float32}()
p = Lux.ComponentArray(p;p1)
p = Lux.ComponentArray(p;p2)
p = Lux.ComponentArray(p;scaling_factor)

The resulting `p` can now be indexed as `p.p1`, `p.p2`, and `p.scaling_factor` for the different parts, and each piece can be passed to its corresponding Lux network.

This is also in the docs now: Simultaneous Fitting of Multiple Neural Networks · DiffEqSensitivity.jl

using Lux, DiffEqFlux, Optimization, OptimizationOptimJL, DifferentialEquations, Random

# RNG passed to Lux.setup below to initialise the network parameters.
rng = Random.default_rng()
"""
    fitz(du, u, p, t)

In-place FitzHugh–Nagumo right-hand side with state `u = (v, w)` and parameters
`p = (a, b, τinv, l)`:

    v' = v - v^3/3 - w + l
    w' = τinv * (v + a - b*w)

The result is written into `du` and the function returns `nothing`, like the
other in-place right-hand sides in this file. It no longer leaks the last
assigned value as a return value.
"""
function fitz(du, u, p, t)
    v, w = u
    a, b, τinv, l = p
    du[1] = v - v^3 / 3 - w + l
    du[2] = τinv * (v + a - b * w)
    return nothing
end

# Ground-truth FitzHugh–Nagumo setup: parameters (a, b, τinv, l), initial state, time span.
p_ = Float32[0.7, 0.8, 1/12.5, 0.5]
u0 = [1f0; 1f0]
tspan = (0f0, 10f0)
prob = ODEProblem(fitz, u0, tspan, p_)
sol = solve(prob, Tsit5(); saveat = 0.5)

# Ideal data, plus a noisy copy: Gaussian noise scaled by 1e-3.
X = Array(sol)
Xₙ = X + Float32(1e-3) * randn(eltype(X), size(X))  # noisy data

# NN_1: models the first derivative from the state (v, w), so it has 2 inputs.
NN_1 = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 1))
p1,st1 = Lux.setup(rng, NN_1)

# NN_2: models the (scaled) second derivative from (v, w, t), so it has 3 inputs.
NN_2 = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 1))
p2, st2 = Lux.setup(rng, NN_2)
scaling_factor = 1f0  # multiplier on NN_2's output; optimised as p.scaling_factor

# Flatten each network's parameter NamedTuple into a ComponentArray.
p1 = Lux.ComponentArray(p1)
p2 = Lux.ComponentArray(p2)

# Concatenate into one parameter vector with named blocks p.p1, p.p2, p.scaling_factor.
# The optimiser sees a single vector, while dudt_ addresses each part by name.
p = Lux.ComponentArray{Float32}()
p = Lux.ComponentArray(p;p1)
p = Lux.ComponentArray(p;p2)
p = Lux.ComponentArray(p;scaling_factor)

"""
    dudt_(u, p, t)

Out-of-place neural ODE right-hand side for the state `u = (v, w)`:
- `du[1]` is `NN_1(v, w)`.
- `du[2]` is `p.scaling_factor * NN_2(v, w, t)`.

The Lux states are read from the globals `st1` / `st2`.
"""
function dudt_(u, p, t)
    v, w = u
    out1, _ = NN_1([v, w], p.p1, st1)
    out2, _ = NN_2([v, w, t], p.p2, st2)
    return [out1[1], p.scaling_factor * out2[1]]
end
# Neural ODE problem: same u0/tspan as the ground truth, with the combined parameters `p`.
prob_nn = ODEProblem(dudt_,u0, tspan, p)
# Trajectory of the untrained model, sampled at the data time points.
sol_nn = solve(prob_nn, Tsit5(),saveat = sol.t)

"""
    predict(θ)

Solve `prob_nn` with parameters `θ` (Vern7, abstol = reltol = 1e-6) and return
the trajectory at the data time points `sol.t` as an array. Gradients use an
interpolating adjoint with ReverseDiff vector-Jacobian products.
"""
function predict(θ)
    sense = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))
    trajectory = solve(prob_nn, Vern7(); p = θ, saveat = sol.t,
                       abstol = 1e-6, reltol = 1e-6, sensealg = sense)
    return Array(trajectory)
end

"""
    loss(θ)

Return `(sse, pred)`. `sse` is the sum of squared errors between the noisy data
`Xₙ` and the prediction `pred` for parameters `θ`. No regularisation term is
applied yet. The second element is presumably handed by Optimization.jl to
`callback(θ, l, pred)`; confirm against the Optimization.jl version in use.
"""
function loss(θ)
    pred = predict(θ)
    sse = sum(abs2, Xₙ .- pred)
    return sse, pred
end
loss(p)  # evaluate the loss once on the initial parameters

# Loss history, appended to every time `callback` is called.
const losses = []

"""
    callback(θ, l, pred)

Record the loss `l`, print it on every 50th call, and return `false` so the
optimisation continues. `θ` and `pred` are accepted but not used.
"""
function callback(θ, l, pred)
    push!(losses, l)
    iszero(length(losses) % 50) && println(losses[end])
    return false
end
# Objective for Optimization.jl using Zygote reverse-mode AD. The second argument is unused.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype)

# Stage 1: ADAM (learning rate 0.01) for at most 500 iterations from the initial parameters.
optprob = Optimization.OptimizationProblem(optf, p)
res1_uode = Optimization.solve(optprob, ADAM(0.01), callback=callback, maxiters = 500)

# Stage 2: refine with BFGS, warm-started from the ADAM result.
optprob2 = Optimization.OptimizationProblem(optf, res1_uode.u)
res2_uode = Optimization.solve(optprob2, BFGS(), maxiters = 10000, callback = callback)