# Composing a Neural ODE with another Neural Network

This is a problem I have been battling with for several weeks now. I want to fit a model that simulates a neural ODE, `ydot = g(y)`, and then passes the output `y(t)` pointwise in time through a second neural network, `y(t) → f(y(t))`. So there are two neural networks, `f` and `g`, that need to be trained. If I fix the parameters of `f` and only train `g`, one gradient computation using adjoint methods takes about 35 ms, which is good. However, if I don't fix `f` and instead compute the gradient over the parameters of `f` and `g` jointly, the computation just hangs. I am at a complete loss as to why this is happening; a MWE is shown below:

```julia
cd(@__DIR__)

using Pkg
Pkg.activate(".")

using Lux
using ComponentArrays
using Zygote
using SciMLSensitivity
using ForwardDiff
using DifferentialEquations
using Random
using BenchmarkTools

# Neural ODE
struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, Sa} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    sensealg::Se
    tspan::T
    saveat::Sa
end

function NeuralODE(model::Lux.AbstractExplicitLayer;
                   solver=Tsit5(),
                   sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()),  # was missing from the signature
                   tspan=(0.0f0, 1.0f0),
                   saveat=[])
    return NeuralODE(model, solver, sensealg, tspan, saveat)
end

function (n::NeuralODE)(ps, st; tspan=n.tspan, saveat=n.saveat)
    function n_ode(u, p, t)
        du, _ = n.model(u, p, st)
        return du
    end
    prob = ODEProblem(ODEFunction(n_ode), zeros(Float32, 2), tspan, ps)  # Float32 to match tspan
    return solve(prob, n.solver; sensealg=n.sensealg, saveat=saveat)
end

# Model which simulates ydot = g(y) and then composes x(t) = f(y(t))
struct ComposedModel{F <: Lux.AbstractExplicitLayer, GODE <: Lux.AbstractExplicitLayer,
                     H <: Int64, P <: Bool} <:
       Lux.AbstractExplicitContainerLayer{(:f, :gode)}
    f::F
    gode::GODE
end

function ComposedModel(f::Lux.AbstractExplicitLayer, g::Lux.AbstractExplicitLayer;
                       solver=Tsit5(),
                       sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()),  # was missing from the signature
                       tspan=(0.0f0, 1.0f0),
                       saveat=[])
    gode = NeuralODE(g; solver=solver, sensealg=sensealg, tspan=tspan, saveat=saveat)
    return ComposedModel(f, gode)
end

function (n::ComposedModel)(ps, st; tspan=n.gode.tspan, saveat=n.gode.saveat)
    return n.f(Array(n.gode(ps.gode, st.gode; tspan=tspan, saveat=saveat)), ps.f, st.f)[1]
end

tspan = (0.0f0, 10.0f0)
Nt = 100
saveat = LinRange{Float32}(tspan[1], tspan[2], Nt)

# Define Neural ODE
hidden_nodes = 10
weight_init_mag = 0.1

f = Lux.Chain(Lux.Dense(2, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, 1; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))

g = Lux.Chain(Lux.Dense(2, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, 2; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))

sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())  # adjoint sensitivity; was commented out but is used below

# Case 1: Fix f and only optimize g
model = NeuralODE(g; solver=Tsit5(), sensealg=sensealg, tspan=tspan, saveat=saveat)

rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)

ps_f, st_f = Lux.setup(rng, f)
ps_f = ComponentArray(ps_f)

function cost(ps)
    return sum(abs2, f(Array(model(ps, st)), ps_f, st_f)[1]) / Nt
end

@benchmark begin
    l, back = pullback(p -> cost(p), ps)
    gs = back(one(l))[1]
end

# Case 2: Optimize both f and g (this hangs!!)
model = ComposedModel(f, g; solver=Tsit5(), sensealg=sensealg, tspan=tspan, saveat=saveat)

rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)

function cost(ps)
    return sum(abs2, model(ps, st)) / Nt
end

@benchmark begin
    l, back = pullback(p -> cost(p), ps)
    gs = back(one(l))[1]
end
```

The problem has nothing to do with adjoint methods at all. Your `ComposedModel` definition runs into infinite recursion because of its type parameters: `H <: Int64` and `P <: Bool` appear in no field, so Julia cannot infer them from the arguments and does not generate the convenience constructor `ComposedModel(f, gode)`. The call `ComposedModel(f, gode)` at the end of your keyword constructor therefore dispatches right back to that same keyword method (a `NeuralODE` is itself a `Lux.AbstractExplicitLayer`), which wraps `gode` in yet another `NeuralODE` and calls itself again, forever.
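The same trap can be reproduced in a few lines, independent of Lux and the ODE machinery. This is a sketch with illustrative names (`Layer`, `A`, `Foo`, `Fixed` are not from the MWE); the escape hatch shown is to call the fully parameterized constructor explicitly, so dispatch cannot loop back into the keyword method:

```julia
abstract type Layer end
struct A <: Layer end

# `P` appears in no field, so Julia cannot infer it from the arguments and does
# not generate a convenience constructor Foo(x). The keyword method below is
# then the only match for a two-argument-free call Foo(x).
struct Foo{T <: Layer, P <: Bool}
    x::T
end

function Foo(x::Layer; flag=true)
    return Foo(x)  # dispatches straight back to this very method: infinite recursion
end

# Escape hatch: drop the unused parameter and construct the fully parameterized
# type explicitly, so this call can never match the keyword method again.
struct Fixed{T <: Layer}
    x::T
end

function Fixed(x::Layer; flag=true)
    return Fixed{typeof(x)}(x)  # hits the default inner constructor, no recursion
end
```

Applied to the MWE, the analogous change is to drop `H` and `P` and return `ComposedModel{typeof(f), typeof(gode)}(f, gode)` from the keyword constructor.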

```julia
struct ComposedModel{F <: Lux.AbstractExplicitLayer, GODE <: Lux.AbstractExplicitLayer,
                     H <: Int64, P <: Bool} <: Lux.AbstractExplicitContainerLayer{(:f, :gode)}
    f::F
    gode::GODE

    function ComposedModel(f::Lux.AbstractExplicitLayer, g::Lux.AbstractExplicitLayer;