# ODEProblem(....) vs NeuralODE(....) for neural ODEs

Hello!

I am wondering about the difference in runtime performance between using `ODEProblem(....)` from DifferentialEquations.jl and `NeuralODE(....)` from DiffEqFlux.jl.

Consider the code below, which defines the RHS explicitly and solves the neural ODE using `ODEProblem(....)`:

```julia
dudt2 = Lux.Chain(Lux.Dense(6, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 6))

function rhs!(du, u, p, t)
    û = dudt2(u, p, st)[1]
    du[1] = û[1]
    du[2] = û[2]
    du[3] = û[3]
    du[4] = û[4]
    du[5] = û[5]
    du[6] = û[6]
end

function predict_neuralode(θ, st, dudt2, tspan, tsteps, u0)
    prob_neuralode = ODEProblem(rhs!, u0, tspan)
    _prob = remake(prob_neuralode, p = θ)
    Array(solve(_prob, saveat = tsteps))
end
```
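For context, a minimal sketch of how these pieces might be wired together (the RNG, initial condition, and time grid below are illustrative placeholders, not values from my actual problem):

```julia
using Lux, DifferentialEquations, ComponentArrays, Random

rng = Random.default_rng()
p, st = Lux.setup(rng, dudt2)   # initialize network parameters and state
θ = ComponentArray(p)           # flat parameter vector for remake/solve

u0 = rand(Float32, 6)           # placeholder initial condition
tspan = (0.0f0, 1.0f0)          # placeholder time span
tsteps = range(tspan[1], tspan[2]; length = 50)

pred = predict_neuralode(θ, st, dudt2, tspan, tsteps, u0)
```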

Now consider the code below, which solves the neural ODE using `NeuralODE(....)` from DiffEqFlux.jl:

```julia
dudt2 = Lux.Chain(Lux.Dense(6, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 6))

function predict_neuralode(p, st, dudt2, tspan, tsteps, u0)
    prob_neuralode = NeuralODE(dudt2, tspan, saveat = tsteps)
    return Array(prob_neuralode(u0, p, st)[1])
end
```
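With the same setup as in the sketch above, and assuming BenchmarkTools.jl for timing, the two versions can be compared directly. Note that both snippets use the name `predict_neuralode`, so the definitions need distinct names (or separate sessions) for a fair comparison:

```julia
using BenchmarkTools

# time a forward solve of the NeuralODE-based version
@btime predict_neuralode($θ, $st, $dudt2, $tspan, $tsteps, $u0)
```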

For my specific problem, the simulation is about 3 times slower when solving the neural ODE via `ODEProblem(....)` than when using `NeuralODE(....)`.

What is the reason for it being so much slower? And is there a way to close this performance gap?

Is it the same solver and options? `NeuralODE` sets a few defaults that make sense for neural ODEs and optimizes a few things based on how it’s normally used. Check the `solve` results.
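As a sketch of what to compare, assuming the variables from the snippets above are in scope:

```julia
sol_ode = solve(_prob; saveat = tsteps)          # ODEProblem version
sol_node, _ = prob_neuralode(u0, p, st)          # NeuralODE version

sol_ode.alg, sol_node.alg                        # which solver was selected?
typeof(sol_ode.interp), typeof(sol_node.interp)  # how is the interpolant stored?
```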

Thanks! So I studied the output of `solve(....)` (for the second case, where I don’t define a RHS function, I studied the output of `NeuralODE(....)`). The only difference I found was in `interp`: the `cache` differs, as shown in the figure below. The variable at the top, `pred_neuralnew`, is the output from `NeuralODE`, and `pred_ode` is the output from `solve` for the `ODEProblem`. Is there a way to make the `cache` the same in both cases?

`dense=false`. You cannot adjoint the interpolation, so it must be set to `false`. I can look at the code later and see.
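A minimal sketch of passing this explicitly in the `ODEProblem` version (everything else left as in the snippet above):

```julia
Array(solve(_prob; saveat = tsteps, dense = false))
```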

I’ve created and uploaded a toy example for a model of multiple chemical reactions taking place.
In this specific case, defining the RHS explicitly and using `ODEProblem(....)` and `solve(....)` together with `dense = false` gives computation times similar to using only `NeuralODE(....)`, although some differences remain.

I am also working on a larger-scale version of this, and there the computation time is significantly worse when using `ODEProblem(....)` and `solve(....)` compared to only `NeuralODE(....)`. I have also noticed that the activation function has a large effect on how similar the computation times are. For instance, with `tanh` the computation times are much closer than with relu-like activation functions such as `swish`.
I am wondering how else one can modify `ODEProblem(....)` and `solve(....)` so that it is equivalent to `NeuralODE(....)`, besides using `dense = false` in `solve`.

test_node_vs_ode.jl (2.7 KB)

It’s just out-of-place and ZygoteVJP.

Did you try an out-of-place definition?

```julia
# out-of-place form: returns the derivative instead of mutating `du`
function rhs!(u, p, t)
    dudt2(u, p, st)[1]
end
```
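A sketch of how this out-of-place definition would plug in, using the solver and adjoint settings that appear later in this thread:

```julia
using DifferentialEquations, SciMLSensitivity  # assuming SciMLSensitivity for the adjoint types

prob = ODEProblem{false}(rhs!, u0, tspan)
_prob = remake(prob, p = θ)
Array(solve(_prob, Vern7(), saveat = tsteps,
            sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())))
```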

So I have benchmarked the code below against just using `NeuralODE(....)`, and when using Adam the computation time is similar. However, when using Adam and then switching to BFGS close to the minimum, the code below is significantly faster than `NeuralODE(....)`. Do you know why? And how can I modify the code below so that the computation time is the same as when using `NeuralODE(....)` together with Adam + BFGS?

```julia
function rhs!(du, u, p, t)
    û = dudt2(u, p, st)[1]
    du[1] = û[1]
    du[2] = û[2]
    du[3] = û[3]
    du[4] = û[4]
    du[5] = û[5]
    du[6] = û[6]
end

function predict_neuralode(θ, st, dudt2, tspan, tsteps, u0)
    # `ff` is presumably `ODEFunction{false}(rhs!; tgrad = basic_tgrad)`,
    # as discussed in the replies below
    prob = ODEProblem{false}(ff, u0, tspan)
    _prob = remake(prob, p = θ)
    Array(solve(_prob, Vern7(), saveat = tsteps,
                sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())))
end
```

I would be surprised if the code above runs, since `rhs!(du, u, p, t)` plus `ODEFunction{false}(rhs!; tgrad = basic_tgrad)` is contradictory: the `false` directly implies it’s only looking for a dispatch `rhs!(u, p, t)`, which doesn’t exist.

Okay, that’s interesting, because it actually did run and even converged. Let me try removing the `{false}` and benchmarking that.

Run it in a new REPL; you’ll see that what you have there requires the function I defined above.

So when I remove `{false}` from `ODEFunction(...)`, I actually get the following error message:

```
ERROR: Nonconforming functions detected. If a model function `f` is defined
as in-place, then all constituent functions like `jac` and `paramjac`
must be in-place (and vice versa with out-of-place). Detected that
some overloads did not conform to the same convention as `f`.
```

However, when using `{false}` together with `rhs!(du, u, p, t)` as I had it before, it works well.

Yes, that’s what I said. The out-of-place definition:

```julia
function rhs!(u, p, t)
    dudt2(u, p, st)[1]
end
```

is required for the `false` version (that’s what it means), and that’s what’s faster for Zygote reverse mode. It should be faster for Adam and BFGS for this use case. I think the code you’re testing with was just mixing this up.

Thanks for the clarification. But what if I want to use the function:

```julia
function rhs!(du, u, p, t)
    û = dudt2(u, p, st)[1]
    du[1] = û[1]
    du[2] = û[2]
    du[3] = û[3]
    du[4] = û[4]
    du[5] = û[5]
    du[6] = û[6]
end
```

This also only works together with `{false}`.

No, that’s the in-place function. It only works with `{true}`, which is the preferred default. What I’m saying is that you probably don’t want to do in-place with neural networks: that’s what currently isn’t optimized in reverse mode.
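To summarize the two conventions side by side (a sketch; the names `rhs_iip!` and `rhs_oop` are illustrative, and `dudt2`, `st`, `u0`, `tspan` are as above):

```julia
# in-place: mutates `du` and pairs with ODEProblem{true}
# (also what ODEProblem infers from a 4-argument function)
function rhs_iip!(du, u, p, t)
    du .= dudt2(u, p, st)[1]
    return nothing
end
prob_iip = ODEProblem{true}(rhs_iip!, u0, tspan)

# out-of-place: returns the derivative and pairs with ODEProblem{false};
# currently the faster choice for neural networks with Zygote reverse mode
rhs_oop(u, p, t) = dudt2(u, p, st)[1]
prob_oop = ODEProblem{false}(rhs_oop, u0, tspan)
```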

Alright, thanks!