ForwardDiffSensitivity Faster than Adjoint Methods

I am having trouble speeding up my Neural ODE gradient computations. I have defined a neural ODE struct that accepts a control input, and I am trying to optimize the parameters of the neural network to minimize an objective. In the MRE posted below, ForwardDiffSensitivity takes approximately 5 seconds to compute the gradient of a loss in which I simulate my neural ODE with 100 different control inputs, while every adjoint sensitivity method I have tried takes 6 seconds or more. Is my code not sufficiently optimized? Why are all the adjoint sensitivity methods performing worse than forward diff?

cd(@__DIR__)

using Pkg
Pkg.activate(".")

using Lux
using ComponentArrays
using Zygote
using SciMLSensitivity
using ForwardDiff
using DifferentialEquations
using Optimization
using OptimizationOptimJL
using OptimizationOptimisers
using Random
using CairoMakie
using SimpleChains
using Interpolations
using BenchmarkTools




# Neural ODE which accepts control as an argument
struct ControlledODE{M <: Lux.AbstractExplicitLayer, So, Se, T, Sa} <:
    Lux.AbstractExplicitContainerLayer{(:model,)}
  model::M
  solver::So
  sensealg::Se
  tspan::T
  saveat::Sa
end

function ControlledODE(model::Lux.AbstractExplicitLayer;
  solver=Tsit5(),
  sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
  tspan=(0.0f0, 1.0f0),
  saveat=[])
  return ControlledODE(model, solver, sensealg, tspan, saveat)
end

function (n::ControlledODE)(u0, control, ps, st; tspan=n.tspan, saveat=n.saveat)
  function n_ode(u, p, t)
    c = Float32.([ctrl(t) for ctrl in control])
    du, _ = n.model([u; c], p, st)
    return du
  end
  prob = ODEProblem(ODEFunction(n_ode), u0, tspan, ps)
  return solve(prob, n.solver; sensealg=n.sensealg, saveat=saveat)
end


tspan = (0.0f0, 10.0f0)
Nt = 100
saveat = LinRange{Float32}(tspan[1], tspan[2], Nt)

# number of controlled trajectories
Nc = 100
rng = MersenneTwister(1111)
disc_controls = randn(rng, Nc, Nt)
controls = [linear_interpolation(saveat, disc_controls[i, :]) for i = 1:Nc]


# Define Neural ODE with controls applied in a for-loop
hidden_nodes = 10
weight_init_mag = 0.1
f = Lux.Chain(Lux.Dense(3, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
                        Lux.Dense(hidden_nodes, 2; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))

rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, f)
ps = ComponentArray(ps)

# Sensitivity algorithm for AD
sensealg = ForwardDiffSensitivity()
#sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
#sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
#sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
#sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP(true))

model = ControlledODE(f; solver=Tsit5(), sensealg=sensealg, tspan=tspan, saveat=saveat)


function cost(ps, u0)
    loss = 0.0f0
    for i = 1:Nc
        pred = model(u0, [controls[i]], ps, st)
        loss += sum(abs2, pred) / Nt
    end
    return loss
end

u0 = zeros(2)
@benchmark begin
    l, back = pullback(p -> cost(p, u0), ps)
    gs = back(one(l))[1]
end

The cost of forward mode depends on the number of ODEs and parameters, not on the number of control inputs. The general cutoff point is that forward will be faster if # ODEs + # parameters < 100 or so.
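
For scale, a quick check of that count against the MRE's own variables (a sketch, not code from the thread):

# Rule-of-thumb count from the reply above, using the MRE's u0 and ps
nstates = length(u0)     # 2 ODE states
nparams = length(ps)     # 282 trainable parameters in this Chain
@show nstates + nparams  # ~284, well past the ~100 cutoff, so adjoints would normally be expected to win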

My current neural network has upwards of 250 parameters, so I assumed it was in the regime where adjoint methods should outperform forward diff.

Perhaps something in my code needs to be optimized? Would making sure my code is type-stable help here? I've also benchmarked the different sensitivity methods against each other on just one control input:

ForwardDiff:
BenchmarkTools.Trial: 121 samples with 1 evaluation.
 Range (min … max):  38.006 ms … 92.124 ms  ┊ GC (min … max): 10.94% … 32.93%
 Time  (median):     41.954 ms              ┊ GC (median):    18.76%
 Time  (mean ± σ):   41.471 ms ± 6.143 ms   ┊ GC (mean ± σ):  15.67% ± 4.33%
 Memory estimate: 68.11 MiB, allocs estimate: 278199.

InterpolatingAdjoint:
BenchmarkTools.Trial: 96 samples with 1 evaluation.
 Range (min … max):  50.839 ms … 58.022 ms  ┊ GC (min … max): 22.61% … 30.17%
 Time  (median):     51.771 ms              ┊ GC (median):    22.50%
 Time  (mean ± σ):   52.148 ms ± 1.418 ms   ┊ GC (mean ± σ):  23.01% ± 1.74%
 Memory estimate: 89.78 MiB, allocs estimate: 342862.

BacksolveAdjoint:
BenchmarkTools.Trial: 78 samples with 1 evaluation.
 Range (min … max):  60.529 ms … 69.281 ms  ┊ GC (min … max): 18.74% … 26.87%
 Time  (median):     66.369 ms              ┊ GC (median):    26.05%
 Time  (mean ± σ):   64.420 ms ± 3.058 ms   ┊ GC (mean ± σ):  23.38% ± 3.57%
 Memory estimate: 109.66 MiB, allocs estimate: 417670.

QuadratureAdjoint:
BenchmarkTools.Trial: 33 samples with 1 evaluation.
 Range (min … max):  152.325 ms … 163.793 ms  ┊ GC (min … max): 23.75% … 27.29%
 Time  (median):     154.511 ms               ┊ GC (median):    24.41%
 Time  (mean ± σ):   156.111 ms ± 3.123 ms    ┊ GC (mean ± σ):  24.89% ± 1.27%
 Memory estimate: 273.74 MiB, allocs estimate: 998588.

Another issue: if I want to parallelize the loop over controls with @Threads.threads, like the following,

(same script as the MRE above, except Nc = 1 and the active sensitivity algorithm is sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP(true)); only the cost function changes)

function cost(ps, u0)
    loss = 0.0f0
    @Threads.threads for i = 1:Nc
        pred = model(u0, [controls[i]], ps, st)
        loss += sum(abs2, pred) / Nt
    end
    return loss
end

u0 = zeros(2)
@benchmark begin
    l, back = pullback(p -> cost(p, u0), ps)
    gs = back(one(l))[1]
end

I get the following error:

ERROR: Compiling Tuple{typeof(Base._wait), Task}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:101 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::typeof(Base._wait), args::Task)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:101
  [3] _pullback
    @ ./threadingconstructs.jl:115 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(Base.Threads.threading_run), ::var"#185#threadsfor_fun#37"{var"#185#threadsfor_fun#36#38"{ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:40, Axis(weight = ViewAxis(1:30, ShapedAxis((10, 3), NamedTuple())), bias = ViewAxis(31:40, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(41:150, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(151:260, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_4 = ViewAxis(261:282, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10), NamedTuple())), bias = ViewAxis(21:22, ShapedAxis((2, 1), NamedTuple())))))}}}, Vector{Float64}, UnitRange{Int64}}}, ::Bool)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
  [5] macro expansion
    @ ./threadingconstructs.jl:168 [inlined]
  [6] _pullback
    @ ~/Research/constitutive_history/optimization/batch_control_mre.jl:88 [inlined]
  [7] _pullback(::Zygote.Context{false}, ::typeof(cost), ::ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:40, Axis(weight = ViewAxis(1:30, ShapedAxis((10, 3), NamedTuple())), bias = ViewAxis(31:40, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(41:150, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(151:260, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_4 = ViewAxis(261:282, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10), NamedTuple())), bias = ViewAxis(21:22, ShapedAxis((2, 1), NamedTuple())))))}}}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
  [8] _pullback
    @ ~/Research/constitutive_history/optimization/batch_control_mre.jl:97 [inlined]
  [9] _pullback(ctx::Zygote.Context{false}, f::var"#39#40", args::ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:40, Axis(weight = ViewAxis(1:30, ShapedAxis((10, 3), NamedTuple())), bias = ViewAxis(31:40, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(41:150, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(151:260, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_4 = ViewAxis(261:282, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10), NamedTuple())), bias = ViewAxis(21:22, ShapedAxis((2, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [10] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:40, Axis(weight = ViewAxis(1:30, ShapedAxis((10, 3), NamedTuple())), bias = ViewAxis(31:40, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(41:150, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(151:260, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_4 = ViewAxis(261:282, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10), NamedTuple())), bias = ViewAxis(21:22, ShapedAxis((2, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:44
 [11] pullback(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:40, Axis(weight = ViewAxis(1:30, ShapedAxis((10, 3), NamedTuple())), bias = ViewAxis(31:40, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(41:150, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(151:260, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_4 = ViewAxis(261:282, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10), NamedTuple())), bias = ViewAxis(21:22, ShapedAxis((2, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:42
 [12] macro expansion
    @ ~/Research/constitutive_history/optimization/batch_control_mre.jl:97 [inlined]
 [13] var"##core#1694"()
    @ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:489
 [14] var"##sample#1695"(::Tuple{}, __params::BenchmarkTools.Parameters)
    @ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:495
 [15] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:99
 [16] #invokelatest#2
    @ ./essentials.jl:818 [inlined]
 [17] invokelatest
    @ ./essentials.jl:813 [inlined]
 [18] #run_result#45
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
 [19] run_result
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
 [20] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117
 [21] run (repeats 2 times)
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117 [inlined]
 [22] #warmup#54
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:169 [inlined]
 [23] warmup(item::BenchmarkTools.Benchmark)
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:168
 [24] top-level scope
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:393
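
One possible workaround for this Zygote limitation is to move the threading outside of the AD call: take a separate gradient per trajectory and sum the losses and gradients afterwards. A minimal sketch against the variables in the MRE (assuming the per-trajectory losses are independent, as they are in this cost function; not code from the thread):

function threaded_loss_and_grad(ps, u0)
    losses = Vector{Float32}(undef, Nc)
    grads = Vector{Any}(undef, Nc)
    Threads.@threads for i in 1:Nc
        # Each task differentiates a single trajectory, so Zygote never sees the threading.
        li, gi = Zygote.withgradient(p -> sum(abs2, model(u0, [controls[i]], p, st)) / Nt, ps)
        losses[i] = li
        grads[i] = gi[1]
    end
    return sum(losses), reduce(+, grads)
end

Whether this pays off depends on the per-trajectory solve and VJP cost relative to the threading overhead.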

That always helps and is the starting point. If you do a flame graph of the forward pass with no differentiation, what are the hot spots?
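
For reference, a minimal way to collect that forward-pass profile (a sketch using the standard-library Profile; ProfileView or PProf can render the flame graph; not code from the thread):

using Profile

model(u0, [controls[1]], ps, st)      # warm up / compile the plain forward solve
Profile.clear()
@profile for _ in 1:100
    model(u0, [controls[1]], ps, st)  # forward solves only, no differentiation
end
Profile.print(format = :flat, sortedby = :count)  # or ProfileView.@profview for a flame graph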

@ChrisRackauckas From just a forward pass (no differentiation), I get a flame graph that looks like this:

https://drive.google.com/drive/folders/1QQcPh8FajL_xNokrpdf_ND-Jm6Bt4iFm?usp=sharing

It doesn’t seem like there is one operation that is taking up all of the time.

Is this the right benchmark? It shows a NeuralODE call taking up most of the time, but you don’t have any in your code above. And I don’t see any interpolations from the controls, so it doesn’t look like it’s measuring the code you’re showing. Also no multithreading is shown.

The selu is taking up about 50%. That's a pretty expensive activation function compared to the other operations and is one key bottleneck here. The allocations from the matrix multiplications are also non-trivial, so at this size SimpleChains would be a pretty major performance improvement (though that is a fairly fixed cost compared to the other things, so if this is just example code, don't worry about it). Other than that, it's just the ODE saving, which is expected.
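
If the activation choice is not essential to the model, one quick experiment suggested by that observation is to swap selu for a cheaper activation and re-run the benchmark; a sketch against the MRE's network definition (not code from the thread):

# Same architecture as the MRE, but with relu, which is much cheaper to evaluate
# (and to differentiate) than selu at this tiny layer size
f_cheap = Lux.Chain(
    Lux.Dense(3, hidden_nodes, relu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
    Lux.Dense(hidden_nodes, hidden_nodes, relu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
    Lux.Dense(hidden_nodes, hidden_nodes, relu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
    Lux.Dense(hidden_nodes, 2; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))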

But again, I don't think it's the right profile, so I wouldn't read too much into it.

Reading your other post, I think you sent the flame graph from the model in Composing a Neural ODE with another Neural Network.

@ChrisRackauckas Apologies for the confusion. I realized that the control input was not causing the latency issue, so I simplified the model to a plain neural ODE that accepts an initial condition and no control.

I am, however, very concerned by the behavior in the post Composing a Neural ODE with another Neural Network, which seems to show a huge slowdown when I optimize a neural ODE composed with a neural network. This is the main problem I am trying to resolve.