# Composing a Neural ODE with another Neural Network

This is a problem I have been battling with for several weeks now. I want to fit a model that simulates a neural ODE, `ydot = g(y)`, and then passes the output `y(t)` pointwise in time through a second neural network, `y(t) → f(y(t))`. So there are two neural networks, `f` and `g`, that need to be trained. If I fix the parameters of `f` and only train `g`, one gradient computation using adjoint methods takes about 35 ms, which is good. However, if I don't fix `f` and instead compute the gradient over the parameters of `f` and `g` jointly, the computation just hangs. I am at a complete loss as to why this is happening; a MWE is shown below:

```julia
cd(@__DIR__)

using Pkg
Pkg.activate(".")

using Lux
using ComponentArrays
using Zygote
using SciMLSensitivity
using ForwardDiff
using DifferentialEquations
using Random
using BenchmarkTools

# Neural ODE
struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, Sa} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    sensealg::Se
    tspan::T
    saveat::Sa
end

function NeuralODE(model::Lux.AbstractExplicitLayer;
                   solver=Tsit5(),
                   sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()),  # was missing from the signature
                   tspan=(0.0f0, 1.0f0),
                   saveat=[])
    return NeuralODE(model, solver, sensealg, tspan, saveat)
end

function (n::NeuralODE)(ps, st; tspan=n.tspan, saveat=n.saveat)
    function n_ode(u, p, t)
        du, _ = n.model(u, p, st)
        return du
    end
    prob = ODEProblem(ODEFunction(n_ode), zeros(Float32, 2), tspan, ps)  # Float32 to match tspan
    return solve(prob, n.solver; sensealg=n.sensealg, saveat=saveat)
end

# Model which simulates ydot = g(y) and then composes x(t) = f(y(t))
struct ComposedModel{F <: Lux.AbstractExplicitLayer, GODE <: Lux.AbstractExplicitLayer,
                     H <: Int64, P <: Bool} <:
       Lux.AbstractExplicitContainerLayer{(:f, :gode)}
    f::F
    gode::GODE
end

function ComposedModel(f::Lux.AbstractExplicitLayer, g::Lux.AbstractExplicitLayer;
                       solver=Tsit5(),
                       sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()),  # was missing from the signature
                       tspan=(0.0f0, 1.0f0),
                       saveat=[])
    gode = NeuralODE(g; solver=solver, sensealg=sensealg, tspan=tspan, saveat=saveat)
    return ComposedModel(f, gode)
end

function (n::ComposedModel)(ps, st; tspan=n.gode.tspan, saveat=n.gode.saveat)
    return n.f(Array(n.gode(ps.gode, st.gode; tspan=tspan, saveat=saveat)), ps.f, st.f)[1]
end

tspan = (0.0f0, 10.0f0)
Nt = 100
saveat = LinRange{Float32}(tspan[1], tspan[2], Nt)

# Define Neural ODE
hidden_nodes = 10
weight_init_mag = 0.1

f = Lux.Chain(Lux.Dense(2, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, 1; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))

g = Lux.Chain(Lux.Dense(2, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, 2; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))

sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())  # adjoint sensitivity; was commented out but is used below

# Case 1: Fix f and only optimize g
model = NeuralODE(g; solver=Tsit5(), sensealg=sensealg, tspan=tspan, saveat=saveat)

rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)

ps_f, st_f = Lux.setup(rng, f)
ps_f = ComponentArray(ps_f)

function cost(ps)
    return sum(abs2, f(Array(model(ps, st)), ps_f, st_f)[1]) / Nt
end

@benchmark begin
    l, back = pullback(p -> cost(p), ps)
    gs = back(one(l))[1]
end

# Case 2: Optimize both f and g (this hangs!!)
model = ComposedModel(f, g; solver=Tsit5(), sensealg=sensealg, tspan=tspan, saveat=saveat)

rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)

function cost(ps)
    return sum(abs2, model(ps, st)) / Nt
end

@benchmark begin
    l, back = pullback(p -> cost(p), ps)
    gs = back(one(l))[1]
end
```

The problem has nothing to do with adjoint methods at all. Your `ComposedModel` definition runs into infinite recursion because of its type parameters: `H <: Int64` and `P <: Bool` appear in no field, so Julia cannot infer them from the arguments and does not generate the convenience constructor `ComposedModel(f, gode)`. The call `ComposedModel(f, gode)` at the end of your keyword constructor therefore dispatches right back to that same keyword method (a `NeuralODE` is itself a `Lux.AbstractExplicitLayer`), which wraps `gode` in yet another `NeuralODE` and calls itself again, forever.
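The same trap can be reproduced in a few lines, independent of Lux and the ODE machinery. This is a sketch with illustrative names (`Layer`, `A`, `Foo`, `Fixed` are not from the MWE); the escape hatch shown is to call the fully parameterized constructor explicitly, so dispatch cannot loop back into the keyword method:

```julia
abstract type Layer end
struct A <: Layer end

# `P` appears in no field, so Julia cannot infer it from the arguments and does
# not generate a convenience constructor Foo(x). The keyword method below is
# then the only match for a two-argument-free call Foo(x).
struct Foo{T <: Layer, P <: Bool}
    x::T
end

function Foo(x::Layer; flag=true)
    return Foo(x)  # dispatches straight back to this very method: infinite recursion
end

# Escape hatch: drop the unused parameter and construct the fully parameterized
# type explicitly, so this call can never match the keyword method again.
struct Fixed{T <: Layer}
    x::T
end

function Fixed(x::Layer; flag=true)
    return Fixed{typeof(x)}(x)  # hits the default inner constructor, no recursion
end
```

Applied to the MWE, the analogous change is to drop `H` and `P` and return `ComposedModel{typeof(f), typeof(gode)}(f, gode)` from the keyword constructor.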

```julia
struct ComposedModel{F <: Lux.AbstractExplicitLayer, GODE <: Lux.AbstractExplicitLayer,
                     H <: Int64, P <: Bool} <: Lux.AbstractExplicitContainerLayer{(:f, :gode)}
    f::F
    gode::GODE

    function ComposedModel(f::Lux.AbstractExplicitLayer, g::Lux.AbstractExplicitLayer;