@bkuwahara what you do is change the parameters into ComponentArrays (from ComponentArrays.jl), and then concatenate them as follows:
p1 = ComponentArray(p1)
p2 = ComponentArray(p2)
p = ComponentArray{Float32}()
p = ComponentArray(p; p1)
p = ComponentArray(p; p2)
p = ComponentArray(p; scaling_factor)
The resulting p now exposes p.p1, p.p2, and p.scaling_factor for the different parts, and Lux will use each individual piece correctly.
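For example (quick sketch continuing from the snippet above), the combined object is still one flat parameter vector as far as the optimizer and AD are concerned; the names are just views into it:
p isa AbstractVector      # true: the optimizer sees a single flat vector
p.p1                      # view of the first network's parameters, keeping Lux's layer layout
p.p2                      # view of the second network's parameters
p.scaling_factor          # the extra scalar parameter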
This is also in the docs now: Simultaneous Fitting of Multiple Neural Networks · SciMLSensitivity.jl
using Lux, DiffEqFlux, ComponentArrays, Optimization, OptimizationOptimisers, OptimizationOptimJL, DifferentialEquations, Random
rng = Random.default_rng()
# FitzHugh-Nagumo model used to generate the training data
function fitz(du,u,p,t)
v,w = u
a,b,τinv,l = p
du[1] = v - v^3/3 -w + l
du[2] = τinv*(v + a - b*w)
end
p_ = Float32[0.7,0.8,1/12.5,0.5]
u0 = [1f0;1f0]
tspan = (0f0,10f0)
prob = ODEProblem(fitz,u0,tspan,p_)
sol = solve(prob, Tsit5(), saveat = 0.5 )
# Ideal data
X = Array(sol)
Xₙ = X + Float32(1e-3)*randn(eltype(X), size(X)) #noisy data
# Neural network for the first state's derivative; takes (v, w) as input
NN_1 = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 1))
p1,st1 = Lux.setup(rng, NN_1)
# Neural network for the second state's derivative; takes (v, w, t) as input
NN_2 = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 1))
p2, st2 = Lux.setup(rng, NN_2)
scaling_factor = 1f0
# Collect both networks' parameters plus the scalar into a single flat ComponentArray
p1 = ComponentArray(p1)
p2 = ComponentArray(p2)
p = ComponentArray{Float32}()
p = ComponentArray(p; p1)
p = ComponentArray(p; p2)
p = ComponentArray(p; scaling_factor)
# Universal ODE: each network's output gives one state's derivative
function dudt_(u,p,t)
v,w = u
z1 = NN_1([v,w], p.p1, st1)[1]
z2 = NN_2([v,w,t], p.p2, st2)[1]
[z1[1],p.scaling_factor*z2[1]]
end
prob_nn = ODEProblem(dudt_,u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(),saveat = sol.t)
function predict(θ)
Array(solve(prob_nn, Vern7(), p=θ, saveat = sol.t,
abstol=1e-6, reltol=1e-6,
sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
# No regularisation right now
function loss(θ)
pred = predict(θ)
sum(abs2, Xₙ .- pred), pred
end
loss(p)
const losses = []
callback(θ,l,pred) = begin
push!(losses, l)
if length(losses)%50==0
println(losses[end])
end
false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
res1_uode = Optimization.solve(optprob, Adam(0.01), callback=callback, maxiters = 500)
optprob2 = Optimization.OptimizationProblem(optf, res1_uode.u)
res2_uode = Optimization.solve(optprob2, BFGS(), maxiters = 10000, callback = callback)
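Once the fit finishes, you can pull each piece back out by name from the optimization result (sketch; p_trained is just a placeholder name, and res2_uode.u is the optimized ComponentArray returned by Optimization.solve):
p_trained = res2_uode.u                 # optimized parameters, still a ComponentArray
p_trained.scaling_factor                # fitted scalar
NN_1(Xₙ[:, 1], p_trained.p1, st1)[1]    # evaluate NN_1 with its fitted parameters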