Gradient Computation Time of Neural ODE

I am training a neural ODE dy/dt = g(y) on Ns trajectories, where each trajectory consists of Nt saved timepoints (y(t1), …, y(t_Nt)). As the number of trajectories Ns and the number of saved timepoints Nt grow, each gradient step begins to take tens or hundreds of seconds. With Ns = 50 simulations and Nt = 5000 timepoints, a single backward pass times at:

34.804460 seconds (17.37 M allocations: 85.357 GiB, 35.31% gc time)

Here is my entire script:

cd(@__DIR__)

using Pkg
Pkg.activate(".")
Pkg.instantiate()

using Lux
using ComponentArrays,
    SciMLSensitivity,
    Optimisers,
    OrdinaryDiffEq,
    Random,
    Statistics,
    Zygote,
    LossFunctions,
    MAT,
    BenchmarkTools

struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, Sa, K} <:
    Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    sensealg::Se
    tspan::T
    saveat::Sa
    kwargs::K
end

function NeuralODE(model::Lux.AbstractExplicitLayer;
    solver=Tsit5(),
    sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
    tspan=(0.0f0, 1.0f0),
    saveat=[],
    kwargs...)
    return NeuralODE(model, solver, sensealg, tspan, saveat, kwargs)
end


function (n::NeuralODE)(u0, ps, st)
    # ODE right-hand side; the Lux model returns (output, state), so take [1]
    function dudt(u, p, t)
        return n.model(u, p, st)[1]
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), u0, n.tspan, ps)
    return Array(solve(prob, n.solver; saveat=n.saveat, sensealg=n.sensealg, n.kwargs...)), st
end

function create_model(tspan, saveat)
    Nt = length(saveat)

    # Sensitivity algorithms tried for the backward pass; uncomment one of the
    # alternatives to benchmark it against the active InterpolatingAdjoint.
    #sensealg = ForwardDiffSensitivity()
    sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    #sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
    #sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
    #sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP(true))

    # Construct the Neural ODE Model
    model = NeuralODE(
        Chain(Dense(1, 100, selu), Dense(100, 100, selu), Dense(100, 100, selu),
            Dense(100, 100, selu), Dense(100, 1, selu));
        sensealg=sensealg,
        tspan=tspan,
        saveat=saveat,
        reltol=1.0f-3,
        abstol=1.0f-3)

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = ComponentArray(ps)

    return model, ps, st
end

function loss(Ns, model, ps, st)
    # Ns trajectories, one per column, all starting from y(0) = 1
    y_pred, st = model(ones(Float32, 1, Ns), ps, st)
    return mean(abs2.(y_pred)), st
end
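
(A minor allocation note, not from the thread: mean(abs2, y_pred) computes the same value as mean(abs2.(y_pred)) without materializing an intermediate array the size of y_pred, and Zygote differentiates both forms.)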

function timecode(Ns, Nt)
    println("NeuralODE Simulations: $Ns")
    println("NeuralODE Simulation Timepoints: $Nt")

    tspan = (0.0f0, 1.0f0)
    saveat = LinRange(tspan[1], tspan[2], Nt)
    model, ps, st = create_model(tspan, saveat)

    loss(Ns, model, ps, st)  # warm-up forward call (triggers compilation)
    (l, _), back = pullback(p -> loss(Ns, model, p, st), ps)

    back((one(l), nothing))  # warm-up backward call
    @time back((one(l), nothing))  # time only the backward (gradient) pass
    return
end

timecode(50, 5000)

Based on the email discussion, if you replace Tsit5 with Euler, you should get a backpropagation time of ~4s.
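
For reference, a minimal sketch of that swap, with chain, tspan, and saveat standing for the objects built in create_model (Euler is fixed-step, so an explicit dt has to be passed through to solve):

# Sketch: fixed-step Euler in place of the adaptive Tsit5.
# One Euler step per saved point is assumed here; tune dt as needed.
model = NeuralODE(chain;
    solver=Euler(),
    sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()),
    tspan=tspan,
    saveat=saveat,
    dt=Float32(step(saveat)))  # saveat is a LinRange, so step() is its spacing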

What’s the main cost here at the moment? The 100x100 matvecs? And what happens when you switch to MKL?
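
For reference, switching the BLAS backend is a one-line change with the MKL.jl package. A sketch, assuming MKL.jl is installed in the active project:

using MKL            # swaps the default OpenBLAS backend via libblastrampoline
using LinearAlgebra
BLAS.get_config()    # the printed config should now list an MKL library

Load it before running the matvec-heavy code so the Dense layers' BLAS calls hit MKL.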

Something worth trying might be an RNN – there are 5000 points between 0 and 1, so a fixed-step recurrence over the same grid should give us the rough cost of the matvecs vs the ODE solve (a sketch follows below).
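
A minimal sketch of that comparison, as a hypothetical helper not taken from the thread: unroll the same network g as the weight-tied recurrence u_{k+1} = u_k + dt*g(u_k), i.e. an explicit Euler discretization, and accumulate the loss inside the loop so Zygote can differentiate it without array mutation:

# Hypothetical fixed-step rollout of the same Lux network as a weight-tied RNN.
# Timing Zygote.pullback of this against the ODE-based loss separates the raw
# matvec cost from solver/adjoint overhead. Zygote unrolls the loop in reverse,
# which is memory-hungry at Nt = 5000 but fine for a one-off measurement.
function rnn_loss(g, u0, ps, st, Nt, dt)
    u = u0
    l = zero(eltype(u0))
    for _ in 1:Nt
        du, _ = g(u, ps, st)  # Lux layers return (output, state)
        u = u .+ dt .* du     # explicit Euler step
        l += sum(abs2, u)
    end
    return l / (Nt * length(u0))
end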

Yeah, with 5000 save points the required dt in the adjoint is rather small: the reverse pass has to stop at every save point to inject the loss gradient, which caps the step size near 1/5000 = 2e-4, so you don’t really gain that much by going to higher order. Indeed, low order does make sense here.

@avikpal @ChrisRackauckas thank you for your advice! Just to make sure I understand: are you saying I can try replacing the neural ODE with simply a finite-depth (fixed number of layers) recurrent neural network?

I’m still confused about the statement “you don’t really gain much by going to higher order.” Does this mean that going from an RNN to a first-order ODE will not give a significant improvement in the fit? How is this related to the step size of the adjoint equation? Thanks in advance!