Suppose I’m trying to train a neural ODE composed of two distinct neural networks, for instance because I want to preserve information about which model states influence which other states.
This worked easily with the old FastChain/DiffEqFlux API, where I just concatenated the two networks' parameters into a single flat vector.
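Reconstructed from memory (so details like the layer sizes and names here are illustrative), that looked roughly like this:

nn1 = FastChain(FastDense(1, 10, tanh), FastDense(10, 1))
nn2 = FastChain(FastDense(1, 10, tanh), FastDense(10, 1))
p1 = initial_params(nn1)
p2 = initial_params(nn2)
θ = [p1; p2]  # one flat vector holding both networks' parameters

function node!(du, u, p, t)
    # each network only sees the other state, so the coupling structure is explicit
    du[1] = nn1([u[2]], p[1:length(p1)])[1]
    du[2] = nn2([u[1]], p[(length(p1)+1):end])[1]
    nothing
end

I'm not sure how to replicate this pattern with Lux.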
Here’s what I have so far:
using Lux, DiffEqFlux, Zygote
using Optimization, OptimizationOptimJL, OptimizationFlux, OptimizationPolyalgorithms
using ComponentArrays
using DifferentialEquations
using LinearAlgebra
using Plots
using Random; rng = Random.default_rng()
# Generate training data from a simple harmonic oscillator
function trueode(du, u, p, t)
    du[1] = u[2]
    du[2] = -u[1]
    nothing
end
u0 = Float32[0; 1]
tspan = (0.0f0, 10.0f0)
p_ = SciMLBase.NullParameters()
prob = ODEProblem(trueode, u0, tspan, p_)
sol = solve(prob, Tsit5(), saveat=0.1)
data = Array(sol)
tsteps = sol.t
struct NDE2Network{du1, du2} <:
Lux.AbstractExplicitContainerLayer{(:network1, :network2)}
network1::du1
network2::du2
end
input_size = output_size = 1
function NDE2Network(hidden_dims)
return NDE2Network(
Lux.Chain(
Lux.Dense(input_size=>hidden_dims, tanh), Lux.Dense(hidden_dims=>output_size)),
Lux.Chain(
Lux.Dense(input_size=>hidden_dims, tanh), Lux.Dense(hidden_dims=>output_size)),
)
end
function (NN::NDE2Network)(du, u, p, t)
    # network1 only sees u[2], network2 only sees u[1];
    # `st` is the (global) state returned by Lux.setup below
    du[1] = NN.network1([u[2]], p.network1, st.network1)[1][1]
    du[2] = NN.network2([u[1]], p.network2, st.network2)[1][1]
    nothing
end
network = NDE2Network(10)
p, st = Lux.setup(rng, network)
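As far as I understand Lux, the container layer's field names mean that p comes back as a NamedTuple nested under :network1 and :network2, and indeed a direct call of the ODE function runs without error:

du_test = zeros(Float32, 2)
network(du_test, u0, ComponentArray(p), 0.0f0)  # fills both entries of du_test via the two networks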
u0 = data[:, 1]
prob_nn = ODEProblem(network, u0, tspan, ComponentArray(p))
function predict(p)
    Array(solve(prob_nn, Tsit5(), p=p, saveat=tsteps))
end
function loss(p)
    pred = predict(p)
    sum(abs2, pred .- data)
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, _) -> loss(p), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p))
res = Optimization.solve(optprob, PolyOpt(), maxiters=500)
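For completeness, the ADAM variant I tried was roughly

res = Optimization.solve(optprob, ADAM(0.05), maxiters=500)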
The prediction and loss functions work fine on their own, but when I try to run the last line (with either PolyOpt or ADAM), Julia spits out a truly gargantuan wall of output (several hundred lines of what looks like LLVM IR) and then crashes. The first few lines are:
Function Attrs: uwtable willreturn mustprogress
define internal fastcc void @preprocess_julia_NDE2Network_11750([2 x [1 x [2 x { i64, i64 }]]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(96) %0, {} addrspace(10)* nonnull align 16 dereferenceable(40) %1, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2, { {} addrspace(10)* } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(8) %3) unnamed_addr #9 !dbg !153 {
top:
%4 = alloca [1 x [2 x i64]], align 8
%5 = alloca [1 x [2 x i64]], align 8
%6 = call {}*** @julia.get_pgcstack() #8
%7 = bitcast {} addrspace(10)* %2 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !154
%8 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %7 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !154
%9 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %8, i64 0, i32 1, !dbg !154
%10 = load i64, i64 addrspace(11)* %9, align 8, !dbg !154, !tbaa !22, !range !25
%11 = icmp ugt i64 %10, 1, !dbg !154
br i1 %11, label %idxend, label %oob, !dbg !154
L27: ; preds = %idxend
%12 = addrspacecast [1 x [2 x i64]]* %4 to [1 x [2 x i64]] addrspace(11)*, !dbg !156
%13 = call fastcc nonnull {} addrspace(10)* @julia_throw_boundserror_11754({} addrspace(10)* nonnull align 16 dereferenceable(40) %56, [1 x [2 x i64]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %12) #10, !dbg !156
unreachable, !dbg !156
Where am I going wrong here?