Hello!
I am wondering about the difference in time performance between using `ODEProblem(....)` from DifferentialEquations.jl and `NeuralODE(....)` from DiffEqFlux.jl.
Consider the code below, which defines the RHS explicitly and solves the neural ODE using `ODEProblem(....)`:
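```julia
using Lux, DifferentialEquations, Random

dudt2 = Lux.Chain(Lux.Dense(6, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 6))

# Parameters and layer states of the network (the states are empty for Dense layers)
rng = Random.default_rng()
ps, st = Lux.setup(rng, dudt2)

# In-place right-hand side. `dudt2` and `st` are not arguments here,
# so they are looked up in global scope on every call.
function rhs!(du, u, p, t)
    û = dudt2(u, p, st)[1]  # a Lux model returns (output, state)
    du[1] = û[1]
    du[2] = û[2]
    du[3] = û[3]
    du[4] = û[4]
    du[5] = û[5]
    du[6] = û[6]
end

function predict_neuralode(θ, st, dudt2, tspan, tsteps, u0)
    prob_neuralode = ODEProblem(rhs!, u0, tspan)
    _prob = remake(prob_neuralode, p = θ)
    Array(solve(_prob, saveat = tsteps))
end
```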
Now consider the code below, which solves the neural ODE using `NeuralODE(....)` from DiffEqFlux.jl:
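```julia
using Lux, DiffEqFlux, DifferentialEquations

dudt2 = Lux.Chain(Lux.Dense(6, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 6))

# NeuralODE wraps the Lux model; calling the layer returns (solution, state)
function predict_neuralode(p, st, dudt2, tspan, tsteps, u0)
    prob_neuralode = NeuralODE(dudt2, tspan, saveat = tsteps)
    return Array(prob_neuralode(u0, p, st)[1])
end
```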
For my specific problem, the simulation is 3 times slower when solving the neural ODE with `ODEProblem(....)` than with `NeuralODE(....)`.
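A minimal, self-contained way to reproduce this kind of timing comparison is sketched below, assuming BenchmarkTools.jl is installed and using its `@btime` macro. The initial state, time span, save points and the function names `predict_odeproblem` / `predict_neuralode_layer` are hypothetical placeholders, not the original problem; both functions follow the code above (the first passes the parameters to `ODEProblem` directly instead of via `remake`, which should be equivalent).

```julia
using Lux, DiffEqFlux, DifferentialEquations, Random, BenchmarkTools

# Same network as in the post
dudt2 = Lux.Chain(Lux.Dense(6, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 6))
rng = Random.Xoshiro(0)
ps, st = Lux.setup(rng, dudt2)

# Hypothetical problem setup (placeholders, not the original problem)
u0 = rand(rng, Float32, 6)
tspan = (0.0f0, 1.0f0)
tsteps = range(tspan[1], tspan[2], length = 11)

# Approach 1: explicit in-place RHS that, like the post, reads the globals
# `dudt2` and `st` (`du .= ...` does the same element-wise copy)
function rhs!(du, u, p, t)
    du .= dudt2(u, p, st)[1]
    return nothing
end

function predict_odeproblem(θ, tspan, tsteps, u0)
    prob = ODEProblem(rhs!, u0, tspan, θ)
    return Array(solve(prob, saveat = tsteps))
end

# Approach 2: the NeuralODE layer from DiffEqFlux, as in the post
function predict_neuralode_layer(θ, st, model, tspan, tsteps, u0)
    node = NeuralODE(model, tspan, saveat = tsteps)
    return Array(node(u0, θ, st)[1])
end

pred1 = predict_odeproblem(ps, tspan, tsteps, u0)
pred2 = predict_neuralode_layer(ps, st, dudt2, tspan, tsteps, u0)

# `$` interpolates the arguments into the benchmark expression
@btime predict_odeproblem($ps, $tspan, $tsteps, $u0)
@btime predict_neuralode_layer($ps, $st, $dudt2, $tspan, $tsteps, $u0)
```

Interpolating with `$` keeps the lookup of the global arguments out of the measurement; it does not change the fact that `rhs!` itself still reads `dudt2` and `st` from global scope inside the solve.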
What is the reason it is so much slower, and is there a way to close this performance gap?
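One difference between the two versions that may be worth isolating: `rhs!` reads `dudt2` and `st` from non-constant global scope, which the Julia performance tips list as a common source of slowdowns, whereas `NeuralODE` builds its right-hand side around the model and state passed to it. The sketch below is a hypothetical variant (the helper name `make_rhs` and the problem setup are made up for illustration) that captures both in a closure so it can be timed against the other two; whether this accounts for the 3x gap in the original problem is an assumption, not something measured here.

```julia
using Lux, DifferentialEquations, Random

dudt2 = Lux.Chain(Lux.Dense(6, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 6))
rng = Random.Xoshiro(0)
ps, st = Lux.setup(rng, dudt2)

# Hypothetical helper: returns an in-place RHS that closes over the model and
# its state, so nothing is looked up in global scope during the solve
function make_rhs(model, state)
    function rhs_closure!(du, u, p, t)
        du .= first(model(u, p, state))  # first(...) drops the returned state
        return nothing
    end
    return rhs_closure!
end

function predict_closure(θ, st, dudt2, tspan, tsteps, u0)
    prob = ODEProblem(make_rhs(dudt2, st), u0, tspan, θ)
    return Array(solve(prob, saveat = tsteps))
end

# Hypothetical problem setup (placeholders, not the original problem)
u0 = rand(rng, Float32, 6)
tspan = (0.0f0, 1.0f0)
tsteps = range(tspan[1], tspan[2], length = 11)

pred = predict_closure(ps, st, dudt2, tspan, tsteps, u0)
```

Because `model` and `state` are function arguments that are never reassigned, the closure captures them with concrete types, which is the property the global-variable version lacks.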