All Adjoint Methods Fail on Simple Neural ODE Example

Hello, I have an MRE where if I use any sensitivity algorithm apart from ForwardDiffSensitivity() I get an error. In particular, when I use the sensitivity algorithm InterpolatingAdjoint(autojacvec=ZygoteVJP()) I get the following error " MethodError: no method matching vec(::Nothing)".

Here is my MRE:

cd(@__DIR__)

using Pkg
Pkg.activate(".")

using Lux
using ComponentArrays
using Zygote
using ForwardDiff
using DifferentialEquations
using Optimization
using OptimizationOptimJL
using OptimizationOptimisers
using Random
using CairoMakie

# Simulate Data
function true_ode(du,u, p, t)
    x, y = u
    du[1] = -y
    du[2] = x
end

# number of simulated trajectories
Ns = 50
u0s = randn(2, Ns)

tspan = (0.0, 10.0)
Nt = 100
saveat = LinRange(tspan[1], tspan[2], Nt)

sols = zeros(2*Nt, Ns)
for i = 1:Ns
    println(i)
    prob = ODEProblem(true_ode, u0s[:, i], tspan)
    sols[:, i] = Array(solve(prob, Tsit5(); saveat=saveat))[:]
end


# Define Neural ODE
weight_init_mag = 0.1
f = Lux.Chain(Lux.Dense(2, 10, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(10, 10, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(10, 10, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(10, 2; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))

rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, f)
ps = ComponentArray(ps)

function f_ode(du, u, p, t)
    du, _ = f(u, p, st)
end

# Define Neural ODE Objective Function
#sensealg = ForwardDiffSensitivity()
#sensealg = BacksolveAdjoint(autojacvec=ReverseDiffVJP())
#sensealg = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))
sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())

function cost(p)
    loss = 0
    for i = 1:Ns
        prob = ODEProblem(f_ode, u0s[:, i], tspan, p)
        pred = Array(solve(prob, Tsit5(); sensealg=sensealg, saveat=saveat))[:]
        loss += sum(abs2, pred - sols[:, i]) / Nt
    end
    return loss
end


# Train Loop
opt = Optimisers.ADAM(1e-2)
st_opt = Optimisers.setup(opt, ps)
numiters = Int(1e3)

for iter = 1:numiters
    println(iter)
    @time begin
    l, back = pullback(cost, ps)
    gs = back(one(l))[1]
    st_opt, ps = Optimisers.update(st_opt, ps, gs)
    end
end

Your f_ode is defined incorrectly. It takes the signature of a in place ODE function but it doesn’t update du

function f_ode(du, u, p, t)
    du_, _ = f(u, p, st)
    du .= du_
end

Thank you, yes it was as simple as that!

I had another follow-up question concerning the speed of training. In the simplest case when I have one trajectory (Ns = 1) in my dataset, if my neural network architecture consists of 3 hidden layers with 100 hidden nodes each, every iteration of ADAM takes approximately ~30 seconds making training infeasible.

Is this the state of the art for training neural ODEs or is there something I can do to speed up the gradient computation in each iteration. I’ve tried several sensitivity algorithms for computing the gradient and found ForwardDiffSensitivity() to be by far the fastest (~30 seconds per gradient computation).

Here is the updated code I am running:

cd(@__DIR__)

using Pkg
Pkg.activate(".")

using Lux
using ComponentArrays
using Zygote
using ForwardDiff
using DifferentialEquations
using Optimization
using OptimizationOptimJL
using OptimizationOptimisers
using Random
using CairoMakie

# Simulate Data
function true_ode(du,u, p, t)
    x, y = u
    du[1] = -y
    du[2] = x
end

# number of simulated trajectories
Ns = 1
u0s = randn(2, Ns)

tspan = (0.0, 10.0)
Nt = 100
saveat = LinRange(tspan[1], tspan[2], Nt)

sols = zeros(2*Nt, Ns)
for i = 1:Ns
    println(i)
    prob = ODEProblem(true_ode, u0s[:, i], tspan)
    sols[:, i] = Array(solve(prob, Tsit5(); saveat=saveat))[:]
end


# Define Neural ODE
hidden_nodes = 100
weight_init_mag = 0.1
f = Lux.Chain(Lux.Dense(2, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, 2; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))

rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, f)
ps = ComponentArray(ps)

function f_ode(du, u, p, t)
    du_, _ = f(u, p, st)
    du .= du_
end

# Define Neural ODE Objective Function
sensealg = ForwardDiffSensitivity()
#sensealg = BacksolveAdjoint(autojacvec=ReverseDiffVJP())
#sensealg = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))
#sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())

function cost(p)
    loss = 0
    for i = 1:Ns
        prob = ODEProblem(f_ode, u0s[:, i], tspan, p)
        pred = Array(solve(prob, Tsit5(); sensealg=sensealg, saveat=saveat))[:]
        loss += sum(abs2, pred - sols[:, i]) / Nt
    end
    return loss
end


# Train Loop
opt = Optimisers.ADAM(1e-2)
st_opt = Optimisers.setup(opt, ps)
numiters = Int(1e3)

for iter = 1:numiters
    println(iter)
    @time begin
    l, back = pullback(cost, ps)
    gs = back(one(l))[1]
    st_opt, ps = Optimisers.update(st_opt, ps, gs)
    end
end

Or use an out of place f if you’re not mutating at all.

Note there’s no reason to loop like that: just use u0 as a matrix and solve it all at once.

SimpleChains is a good idea for performance here if you want to go hardcore.

Thanks Chris, I actually tried to pass all the u0s once as a matrix instead of looping but its about 3 times slower on each iteration. Here is the code I use, where I’ve commented out the part that passes all the u0s in as one matrix.

cd(@__DIR__)

using Pkg
Pkg.activate(".")

using Lux
using ComponentArrays
using Zygote
using ForwardDiff
using DifferentialEquations
using Optimization
using OptimizationOptimJL
using OptimizationOptimisers
using Random
using CairoMakie

# Simulate Data
function true_ode(du,u, p, t)
    x, y = u
    du[1] = -y
    du[2] = x
end

# number of simulated trajectories
Ns = 100
u0s = randn(2, Ns)

tspan = (0.0, 10.0)
Nt = 100
saveat = LinRange(tspan[1], tspan[2], Nt)

sols = zeros(2, Nt, Ns)
for i = 1:Ns
    println(i)
    prob = ODEProblem(true_ode, u0s[:, i], tspan)
    sols[:, :, i] = Array(solve(prob, Tsit5(); saveat=saveat))
end


# Define Neural ODE
hidden_nodes = 10
weight_init_mag = 0.1
f = Lux.Chain(Lux.Dense(2, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, 2; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))

rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, f)
ps = ComponentArray(ps)

function f_ode(du, u, p, t)
    du_, _ = f(u, p, st)
    du .= du_
end

# Define Neural ODE Objective Function
sensealg = ForwardDiffSensitivity()
#sensealg = BacksolveAdjoint(autojacvec=ReverseDiffVJP())
#sensealg = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))
#sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())

function cost(p)
    loss = 0
    
    for i = 1:Ns
        prob = ODEProblem(f_ode, u0s[:, i], tspan, p)
        pred = Array(solve(prob, Tsit5(); sensealg=sensealg, saveat=saveat))[:]
        loss += sum(abs2, pred - sols[:, :, i][:]) / Nt
    end
    
    #prob = ODEProblem(f_ode, u0s, tspan, p)
    #pred = Array(solve(prob, Tsit5(); sensealg=sensealg, saveat=saveat))
    #loss = sum(abs2, pred - sols) / Nt
    return loss
end


# Train Loop
opt = Optimisers.ADAM(1e-2)
st_opt = Optimisers.setup(opt, ps)
numiters = Int(1e3)

for iter = 1:numiters
    println(iter)
    @time begin
    l, back = pullback(cost, ps)
    gs = back(one(l))[1]
    st_opt, ps = Optimisers.update(st_opt, ps, gs)
    end
end

For the other comments, I’m confused which f (my neural network?) I need to define as out-of-place.

For SimpleChains, I’m assuming this library does not interface with Lux? For what neural network sizes (number of hidden nodes) does SimpleChains offer a performance boost over Lux?

You should remove the inplace f_ode with an out of place version. The only thing it is doing here is adding copyto!.


# ForwardDiffSensitviity (Inplace): Unbatched
# 1
#  38.014920 seconds (97.33 M allocations: 6.284 GiB, 8.62% gc time)
# 2
#   0.948341 seconds (4.65 M allocations: 823.119 MiB, 18.59% gc time)
# 3
#   0.927106 seconds (4.75 M allocations: 848.865 MiB, 13.40% gc time)
# 4
#   0.937500 seconds (4.85 M allocations: 874.719 MiB, 13.01% gc time)
# 5
#   1.004452 seconds (5.00 M allocations: 913.194 MiB, 12.93% gc time)
# 6
#   1.064163 seconds (5.08 M allocations: 935.348 MiB, 13.23% gc time)
# 7
#   1.185658 seconds (5.17 M allocations: 956.666 MiB, 14.75% gc time)
# 8
#   1.208982 seconds (5.28 M allocations: 986.817 MiB, 13.02% gc time)
# 9
#   1.267006 seconds (5.53 M allocations: 1.025 GiB, 13.76% gc time)
# 10
#   1.405837 seconds (5.89 M allocations: 1.117 GiB, 13.01% gc time)
# 11
#   1.477285 seconds (6.28 M allocations: 1.216 GiB, 12.19% gc time)


# ForwardDiffSensitviity (Inplace): Batched
# 1
#  38.956579 seconds (71.47 M allocations: 12.208 GiB, 10.67% gc time)
# 2
#   7.361789 seconds (518.71 k allocations: 7.806 GiB, 16.87% gc time)
# 3
#   6.977775 seconds (498.31 k allocations: 7.497 GiB, 17.09% gc time)
# 4
#   7.061502 seconds (495.39 k allocations: 7.518 GiB, 18.33% gc time)
# 5
#   7.259349 seconds (500.27 k allocations: 7.612 GiB, 19.34% gc time)
# 6
#   7.514519 seconds (516.56 k allocations: 7.925 GiB, 20.34% gc time)


# Interpolating Adjoint (Out of Place): Batched
# 1
#  34.839393 seconds (86.51 M allocations: 5.545 GiB, 14.10% gc time)
# 2
#   0.606342 seconds (895.21 k allocations: 529.560 MiB, 46.80% gc time)
# 3
#   0.555090 seconds (893.89 k allocations: 531.121 MiB, 42.04% gc time)
# 4
#   0.552323 seconds (922.64 k allocations: 549.095 MiB, 39.57% gc time)
# 5
#   0.549601 seconds (897.84 k allocations: 537.016 MiB, 40.48% gc time)
# 6
#   0.612896 seconds (987.73 k allocations: 590.683 MiB, 41.25% gc time)
# 7
#   0.717766 seconds (1.17 M allocations: 695.469 MiB, 41.17% gc time)
# 8
#   0.772804 seconds (1.25 M allocations: 740.149 MiB, 40.79% gc time)
# 9
#   0.778109 seconds (1.26 M allocations: 752.737 MiB, 40.50% gc time)
# 10
#   0.839175 seconds (1.32 M allocations: 786.804 MiB, 40.64% gc time)

# Interpolating Adjoint (Out of Place): Batched -- No Globals
# 1
#  24.049004 seconds (56.05 M allocations: 3.384 GiB, 13.12% gc time)
# 2
#   0.309593 seconds (366.78 k allocations: 217.853 MiB, 54.62% gc time)
# 3
#   0.231718 seconds (346.04 k allocations: 205.799 MiB, 42.64% gc time)
# 4
#   0.204137 seconds (334.36 k allocations: 198.529 MiB, 38.77% gc time)
# 5
#   0.238633 seconds (373.85 k allocations: 222.444 MiB, 41.32% gc time)
# 6
#   0.305244 seconds (443.13 k allocations: 263.200 MiB, 45.48% gc time)
# 7
#   0.307950 seconds (498.75 k allocations: 296.492 MiB, 38.88% gc time)
# 8
#   0.378649 seconds (572.80 k allocations: 340.499 MiB, 42.38% gc time)
# 9
#   0.392125 seconds (612.13 k allocations: 363.842 MiB, 40.77% gc time)
# 10
#   0.477306 seconds (713.99 k allocations: 424.495 MiB, 42.52% gc time)
# 11
#   0.433254 seconds (658.85 k allocations: 392.922 MiB, 41.61% gc time)
# 12
#   0.439009 seconds (670.53 k allocations: 400.192 MiB, 41.07% gc time)
# 13
#   0.392162 seconds (599.43 k allocations: 359.815 MiB, 40.62% gc time)
# 14
#   0.416694 seconds (592.84 k allocations: 356.943 MiB, 44.17% gc time)
# 15
#   0.392114 seconds (588.40 k allocations: 354.837 MiB, 41.71% gc time)
# 16
  • Use Interpolating for batched problems (they are a bit too big for ForwardDiff to handle efficiently).
  • Finally, don’t use globals. It is making the entire pipelines type unstable (see the gains by just putting all the code inside a function)

Thank you Avik this is super helpful!

Just another quick question. What exactly is the portion of my code that I should put inside a function? Why does moving from global to local scope make the program more efficient?

Lastly, I was wondering if there is an efficient way of solving an ode as in the MRE above with a batch of different parameter values. For example, I don’t think I can pass a list of 100 choices of parameters into an ODE solve as I can with initial conditions. Is there an efficient way to do something like this without looping?

I’m assuming that in the case of batched initial conditions, instead of doing a naive for loop, the Julia solve function notices that there are 100 initial conditions and enlarges the ODE system to 100 times its original size which allows it to perform one solve with a large state vector to simultaneously simulate the ODE form all the initial conditions. Is this correct?

Also, now with your comments included, the updated MRE looks like

cd(@__DIR__)

using Pkg
Pkg.activate(".")

using Lux
using ComponentArrays
using Zygote
using ForwardDiff
using DifferentialEquations
using Optimization
using OptimizationOptimJL
using OptimizationOptimisers
using Random
using CairoMakie
using SimpleChains

# Simulate Data
function true_ode(du,u, p, t)
    x, y = u
    du[1] = -y
    du[2] = x
end

# number of simulated trajectories
Ns = 100
u0s = randn(2, Ns)

tspan = (0.0, 10.0)
Nt = 100
saveat = LinRange(tspan[1], tspan[2], Nt)

prob = ODEProblem(true_ode, u0s, tspan)
sols = Array(solve(prob, Tsit5(); saveat=saveat))


# Define Neural ODE
hidden_nodes = 10
weight_init_mag = 0.1
f = Lux.Chain(Lux.Dense(2, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, 2; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))

rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, f)
ps = ComponentArray(ps)

function f_ode(u, p, t)
    du, _ = f(u, p, st)
    return du
end

# Define Neural ODE Objective Function
#sensealg = ForwardDiffSensitivity()
#sensealg = BacksolveAdjoint(autojacvec=ReverseDiffVJP())
#sensealg = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))
sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())

function cost(p, f_ode, u0s, tspan, saveat, sensealg, sols)
    Nt = size(sols, 2)
    prob = ODEProblem(ODEFunction(f_ode), u0s, tspan, p)
    pred = Array(solve(prob, Tsit5(); sensealg=sensealg, saveat=saveat))
    loss = sum(abs2, pred - sols) / Nt
    return loss
end


# Train Loop
opt = Optimisers.ADAM(1e-2)
st_opt = Optimisers.setup(opt, ps)
numiters = Int(1e3)

for iter = 1:numiters
    println(iter)
    @time begin
    l, back = pullback(p -> cost(p, f_ode, u0s, tspan, saveat, sensealg, sols), ps)
    gs = back(one(l))[1]
    st_opt, ps = Optimisers.update(st_opt, ps, gs)
    end
end

which takes about 0.15 seconds per gradient step!