Performance when using DifferentialEquations.solve repeatedly inside a loop

I am using the split step method to solve the following equation:

$$i \frac{\partial \psi}{\partial t} = -\frac{1}{2} \frac{\partial^2 \psi}{\partial x^2} - |\psi|^2\psi + i \alpha\left(\frac{\partial^3 \psi}{\partial x^3} + 6 |\psi|^2 \frac{\partial \psi}{\partial x} \right)$$

The first three terms are straightforward to handle with the split-step method. For the fourth term, I essentially need to solve this equation:

$$\frac{\partial \psi}{\partial t} = \alpha\left(6 |\psi|^2 \frac{\partial \psi}{\partial x} \right)$$

after evolving the first three terms. That is, I need to call DifferentialEquations.solve inside an outer time loop to advance the solution by a single step on each iteration. Unfortunately, this is very slow.
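To see where that subproblem comes from, divide the full equation by $i$. This is a worked step added for clarity. It assumes a standard Fourier split-step scheme with the convention $\psi \propto e^{ikx}$, so $\partial_x \to ik$:

$$\frac{\partial \psi}{\partial t} = \frac{i}{2}\frac{\partial^2 \psi}{\partial x^2} + i|\psi|^2\psi + \alpha\frac{\partial^3 \psi}{\partial x^3} + 6\alpha|\psi|^2\frac{\partial \psi}{\partial x}$$

$$\text{linear part:}\quad \hat\psi_k(t) = \hat\psi_k(0)\, e^{-i\left(k^2/2 + \alpha k^3\right)t}, \qquad \text{Kerr part:}\quad \psi(t) = \psi(0)\, e^{i|\psi(0)|^2 t}$$

Both of those substeps have closed-form updates. The last term, $6\alpha|\psi|^2\,\partial_x\psi$, has no such update, which is why it is handed to an ODE solver.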

Here is a simplified example of how this works. I can provide a full MWE if needed, but it's on the longer side:

```julia
using DifferentialEquations, DiffEqOperators

function solve!(...)
    # Prepare some stuff
    # ...

    # Differential Operators
    α = 0.1
    x_order = 2
    D = CenteredDifference(1, x_order, dx, Nx)
    Q = PeriodicBC(Float64)

    function M!(du, u, p, t)
        du = 6*α*D*Q*u.*abs2.(u)
    end

    for i = 1:Nt-1
        @views ψ[:, i+1] .= T(ψ[:, i], dt, M!, ...)
    end

    return nothing
end #solve

function T(ψ, W, dt, F, F̃, M!)
    # Perform some operations in-place on ψ

    # Calling this every loop is slow and allocates a lot
    prob = ODEProblem(M!, ψ, (0.0, dt))
    sol = solve(prob, BS3(), save_everystep = false, dt = dt)
    ψ = sol.u[end]

    return ψ
end
```


As you can see, `T` advances the solution by a single time step: it performs some operations on ψ and then has to solve the ODE. Because this happens on every time step, re-creating the problem and solving it from scratch makes the whole run very time-consuming. Making `M!` in-place helped a bit, but I am stuck at this point.
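For reference, `D*Q*u` in `M!` is meant to apply a second-order centered first derivative on a periodic grid. That stencil can be written out in plain Julia without DiffEqOperators. This is a hypothetical sketch: `nonlinear_rhs!` and its argument order are illustrative names, not part of any package. It writes the right-hand side of $\partial_t\psi = 6\alpha|\psi|^2\partial_x\psi$ directly into `du` with no temporaries:

```julia
# Hypothetical plain-Julia RHS for ∂ψ/∂t = 6α|ψ|² ∂ψ/∂x on a periodic grid,
# using the second-order centered difference (ψ[j+1] - ψ[j-1]) / (2dx).
function nonlinear_rhs!(du, u, α, dx)
    N = length(u)
    @inbounds for j in 1:N
        jp = j == N ? 1 : j + 1   # periodic right neighbour
        jm = j == 1 ? N : j - 1   # periodic left neighbour
        dudx = (u[jp] - u[jm]) / (2dx)
        du[j] = 6α * abs2(u[j]) * dudx
    end
    return du
end

# Possible in-place ODE wrapper (p = (α, dx)):
# f!(du, u, p, t) = (nonlinear_rhs!(du, u, p[1], p[2]); nothing)

# Check on a plane wave ψ = A e^{ix}: the centered difference of e^{ix}
# is i sin(dx)/dx times e^{ix}.
N = 64
dx = 2π / N
x = (0:N-1) .* dx
A = 0.5
α = 0.1
u = A .* exp.(im .* x)
du = similar(u)
nonlinear_rhs!(du, u, α, dx)
```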

Use `du .= 6*α*D*Q*u.*abs2.(u)`. With `du = ...` you're creating a new vector and binding it to the local name `du`, instead of mutating the array the solver passed in.
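A minimal, package-free illustration of rebinding versus mutation. The function names are made up for this example:

```julia
# Rebinding: `du = ...` builds a new array and points the local name at it;
# the caller's array is left untouched.
function rebind!(du, u)
    du = 2 .* u
    return nothing
end

# Mutating: `du .= ...` writes the result into the caller's array.
function mutate!(du, u)
    du .= 2 .* u
    return nothing
end

u = [1.0, 2.0, 3.0]
a = zeros(3); rebind!(a, u)   # a is still all zeros
b = zeros(3); mutate!(b, u)   # b now holds 2 .* u
```

Even with `.=`, the non-dotted product `6*α*D*Q*u` in `M!` is still evaluated into a temporary array before the broadcast. The `.=` only guarantees that the result lands in the solver's `du`.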

Don't create a new solve every dt: why not use a DiscreteCallback or the integrator interface here? You're recreating the solver cache every time, and that's what's hurting.

This is a quick answer because I'm busy, but hopefully it points in the right direction and can be cleared up in a follow-up. I just realized that the integrator interface documentation would be much better with a quick example; https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/test/integrators/iterator_tests.jl could be used to write one.
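Following that suggestion, here is a minimal sketch of the integrator interface for this kind of loop. It assumes OrdinaryDiffEq.jl. The cubic-decay `rhs!` and `linear_substep!` are hypothetical stand-ins for `M!` and the other split-step operations. The problem and its cache are built once, and each outer iteration advances exactly `dt` with `step!(integrator, dt, true)`. By contrast, a bare `step!(integrator)` with an adaptive method such as BS3 takes one step of whatever size the controller picks, which need not be `dt`.

```julia
using OrdinaryDiffEq

# Hypothetical toy RHS standing in for M!: du/dt = -u^3, written in place.
function rhs!(du, u, p, t)
    @. du = -u^3
    return nothing
end

# Hypothetical stand-in for "some operations in-place on ψ".
linear_substep!(u) = (u .*= 0.99; nothing)

dt = 0.01
Nt = 101
u0 = [1.0, 0.5, 0.25]
ψ = zeros(length(u0), Nt)
ψ[:, 1] .= u0

# Build the problem and integrator ONCE. The end of tspan is only an upper
# bound here; the loop below decides how far to integrate.
prob = ODEProblem(rhs!, copy(u0), (0.0, Nt * dt))
integrator = init(prob, BS3(); save_everystep = false)

for i in 1:Nt-1
    # Edit the integrator's state in place (instead of set_u!) and tell it so.
    linear_substep!(integrator.u)
    u_modified!(integrator, true)
    # Advance exactly dt, possibly in several adaptive internal steps.
    step!(integrator, dt, true)
    ψ[:, i+1] .= integrator.u
end
```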


Whoops, not the first time I’ve made this mistake, thank you!

That definitely got me in the right direction, thank you Chris!

I ended up with:

```julia
using DifferentialEquations, DiffEqOperators

function solve!(...)
    # Prepare some stuff
    # ...

    # Differential Operators
    α = 0.1
    x_order = 2
    D = CenteredDifference(1, x_order, dx, Nx)
    Q = PeriodicBC(Float64)
    function M!(du, u, p, t)
        du .= 6*α*D*Q*u.*abs2.(u)
    end

    # Build the problem and integrator once, outside the time loop
    prob = ODEProblem(M!, ψ[:, 1], (0.0, dt))
    integrator = init(prob, BS3(); dt = dt, save_everystep = false)

    for i = 1:Nt-1
        @views ψ[:, i+1] .= T(ψ[:, i], dt, integrator, ...)
    end

    return nothing
end #solve

function T(ψ, dt, integrator, ...)
    # Perform some operations in-place on ψ

    # Now this is much quicker: reuse the integrator instead of re-solving
    set_u!(integrator, ψ)
    step!(integrator)
    ψ = integrator.u

    return ψ
end
```


This improved performance by more than a factor of four! I will look into callbacks to replace set_u!, and keep trying to optimize this further. Thanks again.
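On replacing `set_u!` with a callback, one possible sketch follows. It again assumes OrdinaryDiffEq.jl, with the same hypothetical `rhs!` and `linear_substep!` stand-ins. The other split-step operations go in a `DiscreteCallback` that fires after every step. The step size is fixed with `adaptive = false`, so each step is one `dt`. Because the callback runs after the ODE step, the substeps are applied in the opposite order to `T` (ODE step first, then the other operations).

```julia
using OrdinaryDiffEq

# Hypothetical toy RHS standing in for M!: du/dt = -u^3, written in place.
function rhs!(du, u, p, t)
    @. du = -u^3
    return nothing
end

# Hypothetical stand-in for the other split-step operations.
linear_substep!(u) = (u .*= 0.99; nothing)

# Fire after every step and edit the state in place. A DiscreteCallback
# assumes affect! modified u, so no explicit u_modified! call is needed.
condition(u, t, integrator) = true
function affect!(integrator)
    linear_substep!(integrator.u)
    return nothing
end
cb = DiscreteCallback(condition, affect!; save_positions = (false, false))

dt = 0.01
Nt = 101
u0 = [1.0, 0.5, 0.25]
prob = ODEProblem(rhs!, u0, (0.0, (Nt - 1) * dt))
# Fixed step size so that every step is exactly one split-step dt.
sol = solve(prob, BS3(); dt = dt, adaptive = false, callback = cb,
            save_everystep = false)
ψ_final = sol.u[end]
```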
