I am training a neural ODE that simulates a neural ODE dy/dt = g(y) for Ns trajectories where each trajectory consists of Nt timepoints (y(t1), …, y(t_Nt)). As the number of trajectory samples Ns and the time points saved Nt get large, each gradient step begins to take tens or hundreds of seconds. When I run it with Ns = 50 simulations and Nt = 5000 timepoints I get a single instance time:
34.804460 seconds (17.37 M allocations: 85.357 GiB, 35.31% gc time)
Here is my entire script below:
cd(@__DIR__)
using Pkg
Pkg.activate(".")
Pkg.instantiate()
using Lux
using ComponentArrays,
    SciMLSensitivity,
    Optimisers,
    OrdinaryDiffEq,
    Random,
    Statistics,
    Zygote,
    LossFunctions,
    MAT,
    BenchmarkTools
struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, Sa, K} <:
    Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    sensealg::Se
    tspan::T
    saveat::Sa
    kwargs::K
end
function NeuralODE(model::Lux.AbstractExplicitLayer;
    solver=Tsit5(),
    sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
    tspan=(0.0f0, 1.0f0),
    saveat=[],
    kwargs...)
    return NeuralODE(model, solver, sensealg, tspan, saveat, kwargs)
end
function (n::NeuralODE)(u0, ps, st)
    function dudt(u, p, t)
        return n.model(u, p, st)[1]
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), u0, n.tspan, ps)
    return Array(solve(prob, n.solver; saveat=n.saveat, sensealg=n.sensealg, n.kwargs...)), st
end
function create_model(tspan, saveat)
    Nt = length(saveat)
    #sensealg = ForwardDiffSensitivity()
    sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    #sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
    #sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
    #sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP(true))
    # Construct the Neural ODE Model
    model = NeuralODE(Chain(Dense(1, 100, selu), Dense(100, 100, selu), Dense(100, 100, selu), Dense(100, 100, selu), Dense(100, 1, selu));
                sensealg=sensealg,
                tspan=tspan,
                saveat=saveat,
                reltol=1.0f-3,
                abstol=1.0f-3)
    rng = Random.default_rng()
    Random.seed!(rng, 0)
    ps, st = Lux.setup(rng, model)
    ps = ComponentArray(ps)
    return model, ps, st
end
function loss(Ns, model, ps, st)
    y_pred, st = model(ones(Float32, 1, Ns), ps, st)
    return mean(abs2.(y_pred)), st
end
function timecode(Ns, Nt)
    println("NeuralODE Simulations: $Ns")
    println("NeuralODE Simulation Timepoints: $Nt")
    tspan = (0.0f0, 1.0f0)
    saveat = LinRange(tspan[1], tspan[2], Nt)
    model, ps, st = create_model(tspan, saveat)
    loss(Ns, model, ps, st)
    (l, _), back = pullback(p -> loss(Ns, model, p, st), ps)
    back((one(l), nothing))
    @time back((one(l), nothing))
    return
end
timecode(50, 5000)