ODEProblem(....) vs NeuralODE(....) for neural ODEs

Hello!

I am wondering about the difference in runtime performance between using ODEProblem(....) from DifferentialEquations.jl and NeuralODE(....) from DiffEqFlux.jl.

Consider the code below, which defines the RHS explicitly and solves the neural ODE using ODEProblem(....):

dudt2 = Lux.Chain(Lux.Dense(6, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 6))

function rhs!(du, u, p, t)
    # Evaluate the Lux network (uses the globals `dudt2` and `st`) and
    # write the result into `du` in-place.
    û = dudt2(u, p, st)[1]
    du[1] = û[1]
    du[2] = û[2]
    du[3] = û[3]
    du[4] = û[4]
    du[5] = û[5]
    du[6] = û[6]
end


function predict_neuralode(θ,st,dudt2,tspan,tsteps,u0)
    prob_neuralode = ODEProblem(rhs!, u0, tspan)
    _prob = remake(prob_neuralode, p = θ)
    Array(solve(_prob, saveat = tsteps)) 
end

Now consider the code below, which solves the neural ODE using NeuralODE(....) from DiffEqFlux.jl:

dudt2 = Lux.Chain(Lux.Dense(6, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 8, swish),
                  Lux.Dense(8, 6))


function predict_neuralode(p,st,dudt2,tspan,tsteps,u0)
    prob_neuralode = NeuralODE(dudt2, tspan, saveat = tsteps)
    return Array(prob_neuralode(u0, p, st)[1])
end
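
For reference, both predict functions are driven by roughly the same setup (a minimal sketch; rng, u0, tspan, and tsteps are placeholders for the values in my actual script):

using Lux, DifferentialEquations, DiffEqFlux, ComponentArrays, Random

rng    = Random.default_rng()
ps, st = Lux.setup(rng, dudt2)              # network parameters and state
θ      = ComponentArray(ps)                 # flat parameter vector for the optimizer

u0     = rand(Float32, 6)                   # placeholder initial condition
tspan  = (0.0f0, 1.0f0)                     # placeholder time span
tsteps = range(tspan[1], tspan[2], length = 25)

pred = predict_neuralode(θ, st, dudt2, tspan, tsteps, u0)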

For my specific problem, the simulation is about 3 times slower when using ODEProblem(....) with solve(....) than when using NeuralODE(....).

What is the reason for it being so much slower, and is there a way to close this performance gap?

Is it the same solver and options? NeuralODE sets a few defaults that make sense for neural ODEs and optimizes a few things based on how it’s normally used. Check the solve results.
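
For example, something along these lines (a sketch; `sol_ode` and `sol_node` just stand for the two solution objects):

sol_ode  = solve(_prob, saveat = tsteps)     # from the ODEProblem version
sol_node = prob_neuralode(u0, p, st)[1]      # from the NeuralODE version

sol_ode.alg,   sol_node.alg      # which integrator each one actually used
sol_ode.dense, sol_node.dense    # whether a dense interpolation was stored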

Thanks! So I studied the output of solve(....) (and, for the second case where I don’t define an RHS function, the output of NeuralODE(....)). The only difference I found is in interp: the cache differs, as shown in the figure below. The variable at the top, pred_neuralnew, is the output from NeuralODE, and pred_ode is the output from solve with the ODEProblem. Is there a way to make the cache the same in both cases?

dense = false. You cannot differentiate (adjoint) through the dense interpolation, so NeuralODE has to set it to false. I can look at the code later and check for other differences.
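
In your explicit-RHS version that corresponds to passing it to solve, e.g. (sketch):

Array(solve(_prob, dense = false, saveat = tsteps))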

I’ve created and uploaded a toy example modeling multiple chemical reactions.
In this specific case, defining the RHS explicitly and using ODEProblem(....) and solve(....) with dense = false gives computation times similar to using NeuralODE(....), although some differences remain.

I am also working on a larger-scale version of this, and there the computation time is significantly worse with ODEProblem(....) and solve(....) than with NeuralODE(....). I have also noticed that the activation function has a large effect on the size of the gap: with tanh the computation times are much closer than with relu-like activations such as swish.
I am wondering how else one can modify the ODEProblem(....) and solve(....) approach so that it is equivalent to NeuralODE(....), besides passing dense = false to solve.

test_node_vs_ode.jl (2.7 KB)

It’s just out of place and ZygoteVJP. Did you try an out-of-place definition?

function rhs!(u, p, t)
  # Out-of-place form: return the derivative instead of mutating a buffer.
  dudt2(u, p, st)[1]
end

So I have tried benchmarking the code below against just using NeuralODE(....), and when using Adam the computation times are similar. However, when using Adam and then switching to BFGS close to the minimum, the code below is significantly slower than NeuralODE(....). Do you know why? And how can I modify the code below so that the computation time matches NeuralODE(....) with Adam + BFGS?



function rhs!(du, u, p, t)
    # In-place form: evaluate the Lux network and write the result into `du`.
    û = dudt2(u, p, st)[1]
    du[1] = û[1]
    du[2] = û[2]
    du[3] = û[3]
    du[4] = û[4]
    du[5] = û[5]
    du[6] = û[6]
end

basic_tgrad(u, p, t) = zero(u)   # the RHS has no explicit time dependence

function predict_neuralode(θ,st,dudt2,tspan,tsteps,u0)
    ff = ODEFunction{false}(rhs!; tgrad = basic_tgrad)
    prob = ODEProblem{false}(ff, u0, tspan)
    _prob = remake(prob, p = θ)
    Array(solve(_prob, Vern7(), saveat = tsteps, sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())))
end

I would be surprised if that code runs, since rhs!(du, u, p, t) plus ODEFunction{false}(rhs!; tgrad = basic_tgrad) is contradictory: the {false} means it will only look for a three-argument dispatch rhs!(u, p, t), which doesn’t exist in that snippet.

Okay, that’s interesting, because it actually did run and even converged. Let me try removing the {false} and benchmark that.

Run it in a fresh REPL and you’ll see that what you have there requires the out-of-place function I defined above (it was still defined in your session, which is why it ran).

So when I remove the {false} from ODEFunction(...), I actually get the following error message:


ERROR: Nonconforming functions detected. If a model function `f` is defined
as in-place, then all constituent functions like `jac` and `paramjac`
must be in-place (and vice versa with out-of-place). Detected that
some overloads did not conform to the same convention as `f`.

Nonconforming functions: ["tgrad"]

However, when using {false} together with rhs!(du, u, p, t), as I had it before, it works fine.

Yes, that’s what I said. The out of place definition:

function rhs!(u, p, t)
  dudt2(u, p, st)[1]
end

is required for the {false} version (that’s what it means), and that’s what’s faster for Zygote reverse mode. It should be faster for both Adam and BFGS in this use case. I think the code you were testing was just mixing this up.
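
Putting it together, the consistent out-of-place setup would look roughly like this (a sketch; `rhs_oop` is just a separate name to avoid clashing with the in-place method, and the tgrad is the usual zero since the RHS has no explicit time dependence):

rhs_oop(u, p, t) = dudt2(u, p, st)[1]      # out-of-place: return du rather than mutate
basic_tgrad(u, p, t) = zero(u)             # no explicit t dependence

function predict_neuralode(θ, st, dudt2, tspan, tsteps, u0)
    ff   = ODEFunction{false}(rhs_oop; tgrad = basic_tgrad)
    prob = ODEProblem{false}(ff, u0, tspan, θ)
    Array(solve(prob, Vern7(), saveat = tsteps,
                sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())))
end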

Thanks for the clarification. But what if I want to use the function:

function rhs!(du, u, p, t)
    û = dudt2(u, p, st)[1]
    du[1] = û[1]
    du[2] = û[2]
    du[3] = û[3]
    du[4] = û[4]
    du[5] = û[5]
    du[6] = û[6]
end

This also only works together with {false}.

No, that’s the in-place function. It only works with {true}, which is what’s preferred by default. What I’m saying is that you probably don’t want to use in-place with neural networks: that’s the form that currently isn’t optimized in reverse mode.
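
For completeness, the matching all-in-place combination would look roughly like this (a sketch; `basic_tgrad!` and `predict_neuralode_inplace` are just illustrative names, and the in-place tgrad takes the output buffer as its first argument), but as said, this is the path that currently isn’t optimized for reverse mode:

basic_tgrad!(dT, u, p, t) = (dT .= 0)      # in-place tgrad to match the in-place f

function predict_neuralode_inplace(θ, st, dudt2, tspan, tsteps, u0)
    ff   = ODEFunction{true}(rhs!; tgrad = basic_tgrad!)
    prob = ODEProblem{true}(ff, u0, tspan, θ)
    Array(solve(prob, Vern7(), saveat = tsteps,
                sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())))
end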

Alright thanks!