I am training a neural ODE dy/dt = g(y), where g is a neural network, by simulating it for Ns trajectories, where each trajectory consists of Nt saved timepoints (y(t_1), …, y(t_Nt)). As the number of trajectories Ns and the number of saved timepoints Nt get large, each gradient step begins to take tens or hundreds of seconds. When I run it with Ns = 50 trajectories and Nt = 5000 timepoints, a single backward (gradient) pass takes:
34.804460 seconds (17.37 M allocations: 85.357 GiB, 35.31% gc time)
Here is my entire script below:
# Switch to the directory containing this script so that "." below refers to it.
cd(@__DIR__)
using Pkg
# Activate the project environment in the current directory and make sure its
# dependencies are installed before any package is loaded.
Pkg.activate(".")
Pkg.instantiate()
using Lux
using ComponentArrays,
SciMLSensitivity,
Optimisers,
OrdinaryDiffEq,
Random,
Statistics,
Zygote,
LossFunctions,
MAT,
BenchmarkTools
"""
    NeuralODE{Layer,Solver,Sens,TSpan,SaveAt,KW}

Lux container layer that bundles a Lux `model` with the settings used when the
layer is called to build and solve an `ODEProblem`. Only the `model` field is
registered as a sub-layer of the container.
"""
struct NeuralODE{
    Layer<:Lux.AbstractExplicitLayer,Solver,Sens,TSpan,SaveAt,KW
} <: Lux.AbstractExplicitContainerLayer{(:model,)}
    model::Layer      # network evaluated as the right-hand side of the ODE
    solver::Solver    # algorithm passed positionally to `solve`
    sensealg::Sens    # forwarded to `solve` as the `sensealg` keyword
    tspan::TSpan      # time span given to the `ODEProblem`
    saveat::SaveAt    # forwarded to `solve` as the `saveat` keyword
    kwargs::KW        # remaining keyword arguments, splatted into `solve`
end
"""
    NeuralODE(model::Lux.AbstractExplicitLayer; solver, sensealg, tspan, saveat, kwargs...)

Keyword-based convenience constructor for `NeuralODE`.

# Keywords
- `solver`: defaults to `Tsit5()`.
- `sensealg`: defaults to `InterpolatingAdjoint(; autojacvec=ZygoteVJP())`.
- `tspan`: defaults to `(0.0f0, 1.0f0)`.
- `saveat`: defaults to an empty vector.
- `kwargs...`: any further keywords are stored as-is in the `kwargs` field.
"""
NeuralODE(
    model::Lux.AbstractExplicitLayer;
    solver=Tsit5(),
    sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
    tspan=(0.0f0, 1.0f0),
    saveat=[],
    kwargs...,
) = NeuralODE(model, solver, sensealg, tspan, saveat, kwargs)
# Forward pass of the layer: solve the ODE defined by `n.model` starting from `u0`
# with parameters `ps`, and return `(Array(solution), st)`. The incoming state `st`
# is handed back unchanged.
function (n::NeuralODE)(u0, ps, st)
    # Out-of-place right-hand side: evaluate the wrapped model and keep only its
    # first return value (the output), dropping the state it returns alongside.
    rhs(u, p, t) = n.model(u, p, st)[1]

    ode = ODEProblem{false}(ODEFunction{false}(rhs), u0, n.tspan, ps)
    sol = solve(ode, n.solver; saveat=n.saveat, sensealg=n.sensealg, n.kwargs...)
    return Array(sol), st
end
"""
    create_model(tspan, saveat)

Build the benchmark `NeuralODE` and initialise its parameters and states.

The right-hand side is a `Chain` of five `Dense` layers (1 → 100 → 100 → 100 →
100 → 1), all with `selu` activation. The solve is configured with
`reltol = abstol = 1.0f-3` and an `InterpolatingAdjoint` using `ZygoteVJP`.

# Arguments
- `tspan`: time span stored in the layer.
- `saveat`: save times stored in the layer.

# Returns
`(model, ps, st)` where `ps` is the parameter set wrapped in a `ComponentArray`
and `st` is the state returned by `Lux.setup`.

Note: this seeds `Random.default_rng()` with `0`, so repeated calls produce the
same initial parameters.
"""
function create_model(tspan, saveat)
    # Sensitivity algorithm used for the gradient. Alternatives kept here for
    # experimentation:
    #   ForwardDiffSensitivity()
    #   BacksolveAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
    #   InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
    #   QuadratureAdjoint(autojacvec=ZygoteVJP(true))
    sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())

    # Construct the Neural ODE Model
    model = NeuralODE(
        Chain(
            Dense(1, 100, selu),
            Dense(100, 100, selu),
            Dense(100, 100, selu),
            Dense(100, 100, selu),
            Dense(100, 1, selu),
        );
        sensealg=sensealg,
        tspan=tspan,
        saveat=saveat,
        reltol=1.0f-3,
        abstol=1.0f-3,
    )

    rng = Random.default_rng()
    Random.seed!(rng, 0)
    ps, st = Lux.setup(rng, model)
    # Wrap at the return site instead of rebinding `ps` to a different type.
    return model, ComponentArray(ps), st
end
"""
    loss(Ns, model, ps, st)

Call `model` on a `1 × Ns` `Float32` matrix of ones with parameters `ps` and
state `st`, and return `(mean of the squared outputs, state returned by model)`.
"""
function loss(Ns, model, ps, st)
    u0 = ones(Float32, 1, Ns)
    trajectories, st_out = model(u0, ps, st)
    mse = mean(abs2.(trajectories))
    return mse, st_out
end
"""
    timecode(Ns, Nt)

Print the problem size, build the model with `Nt` evenly spaced save times on
`(0.0f0, 1.0f0)`, and time one evaluation of the loss pullback with `@time`.
The loss and the pullback are each run once beforehand so that the timed call is
not the first invocation. Returns `nothing`.
"""
function timecode(Ns, Nt)
    println("NeuralODE Simulations: $Ns")
    println("NeuralODE Simulation Timepoints: $Nt")

    t_start, t_stop = 0.0f0, 1.0f0
    save_times = LinRange(t_start, t_stop, Nt)
    model, ps, st = create_model((t_start, t_stop), save_times)

    # Untimed forward evaluation.
    loss(Ns, model, ps, st)

    # Differentiate the loss with respect to the parameters only.
    objective = p -> loss(Ns, model, p, st)
    (l, _), back = pullback(objective, ps)

    # Cotangent for the `(loss, state)` output: one for the loss, nothing for the state.
    seed = (one(l), nothing)
    back(seed)        # untimed first call
    @time back(seed)  # measured call
    return nothing
end
# Run the benchmark with Ns = 50 simulations and Nt = 5000 saved timepoints.
timecode(50, 5000)