Composing a Neural ODE with another Neural Network

This is a problem I have been battling for several weeks now. I want to fit a model that simulates a neural ODE, ydot = g(y), and then passes the solution y(t) pointwise in time through another neural network, y(t) → f(y(t)). So there are two neural networks, f and g, that need to be trained. If I fix the parameters of f and only train g, one gradient computation using adjoint methods takes about 35 ms, which is good. However, if I don't fix f and instead compute the gradient over the parameters of f and g jointly using adjoint methods, the computation just hangs. I am at a complete loss as to why this is happening; an MWE is shown below:

cd(@__DIR__)

using Pkg
Pkg.activate(".")

using Lux
using ComponentArrays
using Zygote
using SciMLSensitivity
using ForwardDiff
using DifferentialEquations
using Random
using BenchmarkTools


# Neural ODE
struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, Sa} <:
    Lux.AbstractExplicitContainerLayer{(:model,)}
  model::M
  solver::So
  sensealg::Se
  tspan::T
  saveat::Sa
end

function NeuralODE(model::Lux.AbstractExplicitLayer;
  solver=Tsit5(),
  sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
  tspan=(0.0f0, 1.0f0),
  saveat=[])
  return NeuralODE(model, solver, sensealg, tspan, saveat)
end

function (n::NeuralODE)(ps, st; tspan=n.tspan, saveat=n.saveat)
  function n_ode(u, p, t)
    du, _ = n.model(u, p, st)
    return du
  end
  prob = ODEProblem(ODEFunction(n_ode), zeros(2), tspan, ps)
  return solve(prob, n.solver; sensealg=n.sensealg, saveat=saveat)
end


# Model which simulates ydot = g(y) and then composes x(t) = f(y(t))
struct ComposedModel{F <: Lux.AbstractExplicitLayer, GODE <: Lux.AbstractExplicitLayer, H <: Int64, P <: Bool} <:
  Lux.AbstractExplicitContainerLayer{(:f, :gode)}
  f::F
  gode::GODE
end

function ComposedModel(f::Lux.AbstractExplicitLayer, g::Lux.AbstractExplicitLayer;
  solver=Tsit5(),
  sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
  tspan=(0.0f0, 1.0f0),
  saveat=[])
  gode = NeuralODE(g; solver=solver, sensealg=sensealg, tspan=tspan, saveat=saveat)
  return ComposedModel(f, gode)
end

function (n::ComposedModel)(ps, st; tspan=n.gode.tspan, saveat=n.gode.saveat)
  return n.f(Array(n.gode(ps.gode, st.gode; tspan=tspan, saveat=saveat)), ps.f, st.f)[1]
end


tspan = (0.0f0, 10.0f0)
Nt = 100
saveat = LinRange{Float32}(tspan[1], tspan[2], Nt)

# Define Neural ODE
hidden_nodes = 10
weight_init_mag = 0.1

f = Lux.Chain(Lux.Dense(2, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, 1; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))


g = Lux.Chain(Lux.Dense(2, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, 2; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))

# Sensitivity algorithm for AD
#sensealg = ForwardDiffSensitivity()
sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
#sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
#sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
#sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP(true))




# Case 1: Fix f and only optimize g
model = NeuralODE(g; solver=Tsit5(), sensealg=sensealg, tspan=tspan, saveat=saveat)

rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)

ps_f, st_f = Lux.setup(rng, f)
ps_f = ComponentArray(ps_f)

function cost(ps)
    return sum(abs2, f(model(ps, st), ps_f, st_f)[1]) / Nt
end

@benchmark begin
    l, back = pullback(p -> cost(p), ps)
    gs = back(one(l))[1]
end




# Case 2: Optimize both f and g (this hangs!!)
model = ComposedModel(f, g; solver=Tsit5(), sensealg=sensealg, tspan=tspan, saveat=saveat)

rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)

function cost(ps)
    return sum(abs2, model(ps, st)) / Nt
end

@benchmark begin
    l, back = pullback(p -> cost(p), ps)
    gs = back(one(l))[1]
end

The problem has nothing to do with adjoint methods or the sensitivity algorithms. Your ComposedModel definition gets into an infinite recursion because of the types: the parameters H and P never appear in any field, so Julia cannot infer them and provides no usable positional ComposedModel(f, gode) constructor. The call at the end of your keyword constructor therefore dispatches straight back to that same keyword method (a NeuralODE is itself a Lux.AbstractExplicitLayer), wrapping g in NeuralODEs forever.
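As a minimal sketch of the dispatch problem, here is a throwaway Foo type (hypothetical, not from your code) with a parameter that appears in no field:

struct Foo{A, B}
    a::A
end

# B cannot be deduced from the argument, so Julia provides no positional
# Foo(a) convenience constructor and this call fails with a MethodError.
Foo(1)

In your case there is no MethodError, because the keyword constructor still matches and simply keeps calling itself. The fix below adds an inner constructor that supplies the type parameters explicitly: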

struct ComposedModel{F <: Lux.AbstractExplicitLayer, GODE <: Lux.AbstractExplicitLayer,
    H <: Int64, P <: Bool} <: Lux.AbstractExplicitContainerLayer{(:f, :gode)}
    f::F
    gode::GODE

    function ComposedModel(f::Lux.AbstractExplicitLayer, g::Lux.AbstractExplicitLayer;
        solver=Tsit5(), sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
        tspan=(0.0f0, 1.0f0), saveat=[])
        gode = NeuralODE(g; solver, sensealg, tspan, saveat)
        # I don't know what H and P are. Putting random values rn.
        return new{typeof(f), typeof(gode), 5, Bool}(f, gode)
    end
end
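Alternatively, assuming H and P are not actually needed (nothing in the struct body uses them), a minimal sketch of a simpler fix is to drop the unused parameters and call the parameterized constructor directly, so dispatch can never land back on the keyword method:

struct ComposedModel{F <: Lux.AbstractExplicitLayer, GODE <: Lux.AbstractExplicitLayer} <:
    Lux.AbstractExplicitContainerLayer{(:f, :gode)}
    f::F
    gode::GODE
end

function ComposedModel(f::Lux.AbstractExplicitLayer, g::Lux.AbstractExplicitLayer;
    solver=Tsit5(), sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
    tspan=(0.0f0, 1.0f0), saveat=[])
    gode = NeuralODE(g; solver, sensealg, tspan, saveat)
    # Explicit type parameters select the default ComposedModel{F, GODE}
    # constructor instead of re-entering this keyword method.
    return ComposedModel{typeof(f), typeof(gode)}(f, gode)
end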

Thank you, yes, it was a trivial error!