Simultaneous training of multiple neural networks with Lux

@bkuwahara what you do is convert the parameters to component arrays, and then concatenate them as follows:

# Convert each network's parameter set into a ComponentArray.
p1 = Lux.ComponentArray(p1)
p2 = Lux.ComponentArray(p2)

# Start from an empty Float32 ComponentArray and add one named component per call.
# `p;p1` is keyword shorthand for `p1 = p1`, so each component is named after its variable.
p = Lux.ComponentArray{Float32}()
p = Lux.ComponentArray(p;p1)
p = Lux.ComponentArray(p;p2)
p = Lux.ComponentArray(p;scaling_factor)

The resulting p can then be indexed as p.p1, p.p2, and p.scaling_factor to access the different parts, so each network is simply called with its own piece of p (for example, NN_1 with p.p1 and NN_2 with p.p2).

This is also in the docs now: Simultaneous Fitting of Multiple Neural Networks · SciMLSensitivity.jl

using Lux, DiffEqFlux, Optimization, OptimizationOptimJL, DifferentialEquations, Random

# Random number generator handed to `Lux.setup` when the network parameters are created.
rng = Random.default_rng()
"""
    fitz(du, u, p, t)

In-place right-hand side of the two-state model (FitzHugh–Nagumo form) used to
generate the reference data.

# Arguments
- `du`: output buffer; `du[1]` and `du[2]` are overwritten with the derivatives.
- `u`: current state, unpacked as `(v, w)`.
- `p`: parameters, unpacked as `(a, b, τinv, l)`.
- `t`: time; not used by these equations.

The derivatives written are

    dv/dt = v - v^3/3 - w + l
    dw/dt = τinv * (v + a - b*w)
"""
function fitz(du, u, p, t)
  v, w = u
  a, b, τinv, l = p
  dv = v - v^3/3 - w + l
  dw = τinv * (v + a - b*w)
  du[1] = dv
  du[2] = dw
end

# Parameters (a, b, τinv, l) of `fitz`, initial state (v, w), and integration time span.
p_ = Float32[0.7,0.8,1/12.5,0.5]
u0 = [1f0;1f0]
tspan = (0f0,10f0)
# Reference trajectory of the known model, saved every 0.5 time units.
prob = ODEProblem(fitz,u0,tspan,p_)
sol = solve(prob, Tsit5(), saveat = 0.5 )

# Ideal data
X = Array(sol)
# Training target: the ideal data plus normally distributed noise scaled by 1e-3.
Xₙ = X + Float32(1e-3)*randn(eltype(X), size(X))  #noisy data

# Network for the first state derivative: 2 inputs -> 16 (tanh) -> 1 output.
# `dudt_` calls it with [v, w].
NN_1 = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 1))
p1,st1 = Lux.setup(rng, NN_1)

# Network for the second state derivative: 3 inputs -> 16 (tanh) -> 1 output.
# `dudt_` calls it with [v, w, t].
NN_2 = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 1))
p2, st2 = Lux.setup(rng, NN_2)
# Extra scalar parameter multiplying NN_2's output in `dudt_`; fitted along with the networks.
scaling_factor = 1f0

# Convert each network's parameter set into a ComponentArray.
p1 = Lux.ComponentArray(p1)
p2 = Lux.ComponentArray(p2)

# Build the single parameter object handed to the optimizer, one named component per call.
# `p;p1` is keyword shorthand for `p1 = p1`, so the parts are read back as `p.p1`, `p.p2`
# and `p.scaling_factor`.
p = Lux.ComponentArray{Float32}()
p = Lux.ComponentArray(p;p1)
p = Lux.ComponentArray(p;p2)
p = Lux.ComponentArray(p;scaling_factor)

"""
    dudt_(u, p, t)

Out-of-place right-hand side in which both state derivatives come from neural networks.

# Arguments
- `u`: current state, unpacked as `(v, w)`.
- `p`: combined parameters; reads `p.p1`, `p.p2` and `p.scaling_factor`.
- `t`: time; passed to `NN_2` as a third input.

# Returns
A 2-element vector: the first output of `NN_1([v, w])`, and the first output of
`NN_2([v, w, t])` multiplied by `p.scaling_factor`.
"""
function dudt_(u, p, t)
    v, w = u
    # Each network call yields a tuple whose first element is the network output
    # (Lux layers return `(output, state)`); the remainder is not needed here.
    out1 = first(NN_1([v, w], p.p1, st1))
    out2 = first(NN_2([v, w, t], p.p2, st2))
    dv = out1[1]
    dw = p.scaling_factor * out2[1]
    return [dv, dw]
end
# ODE problem driven by the two networks, set up with the initial (untrained) parameters `p`.
# `predict` re-solves this problem with updated parameters.
prob_nn = ODEProblem(dudt_,u0, tspan, p)
# Forward solve with the initial parameters at the reference save times.
# NOTE(review): `sol_nn` is not read again in the code shown here.
sol_nn = solve(prob_nn, Tsit5(),saveat = sol.t)

"""
    predict(θ)

Solve `prob_nn` with parameters `θ` and return the trajectory as an `Array`.

The solve uses `Vern7()`, saves at the time points of the reference solution (`sol.t`),
and uses absolute and relative tolerances of `1e-6`. `sensealg` is set to
`InterpolatingAdjoint` with a `ReverseDiffVJP(true)` vector-Jacobian product, which
configures how the solve is differentiated during training.
"""
function predict(θ)
    sensitivity = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))
    trajectory = solve(prob_nn, Vern7(); p = θ, saveat = sol.t,
                       abstol = 1e-6, reltol = 1e-6,
                       sensealg = sensitivity)
    return Array(trajectory)
end

"""
    loss(θ)

Return `(sse, pred)`, where `pred = predict(θ)` and `sse` is the sum of squared
differences between the noisy data `Xₙ` and `pred`.

No regularisation term is included right now.
"""
function loss(θ)
    pred = predict(θ)
    residual = Xₙ .- pred
    sse = sum(abs2, residual)
    return sse, pred
end

# Evaluate the loss once at the initial parameters before training
# (presumably a sanity check / warm-up; the result is not stored).
loss(p)
# History of every loss value passed to `callback`, in call order.
const losses = []

"""
    callback(θ, l, pred)

Record the loss `l` in `losses` and print the latest value on every 50th recorded entry.
`θ` and `pred` are accepted but not used.

Always returns `false` (presumably the optimizer's "do not stop" signal — confirm
against the Optimization.jl callback documentation).
"""
function callback(θ, l, pred)
    push!(losses, l)
    # Print only every 50th recorded value.
    iszero(length(losses) % 50) && println(last(losses))
    return false
end
# Automatic-differentiation backend selection passed to the OptimizationFunction.
adtype = Optimization.AutoZygote()
# Wrap `loss` in the two-argument `(x, p)` form; the second argument is ignored.
optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype)

# Stage 1: start from the initial parameters `p` and run ADAM for 500 iterations
# (the 0.01 is presumably the learning rate — confirm against the optimizer docs).
optprob = Optimization.OptimizationProblem(optf, p)
res1_uode = Optimization.solve(optprob, ADAM(0.01), callback=callback, maxiters = 500)

# Stage 2: restart from the stage-1 result and continue with BFGS, up to 10000 iterations.
# Both stages share `callback`, so `losses` accumulates values across the two runs.
optprob2 = Optimization.OptimizationProblem(optf, res1_uode.u)
res2_uode = Optimization.solve(optprob2, BFGS(), maxiters = 10000, callback = callback)
1 Like