Performance when using DifferentialEquations.solve repeatedly inside a loop

I am using the split-step method to solve the following equation:

i \frac{\partial \psi}{\partial t} = -\frac{1}{2} \frac{\partial^2 \psi}{\partial x^2} - |\psi|^2\psi + i \alpha\left(\frac{\partial^3 \psi}{\partial x^3} + 6 |\psi|^2 \frac{\partial \psi}{\partial x} \right)

The first 3 terms are straightforward when using the split step method. For the fourth term, I basically need to solve this equation:

\frac{\partial \psi}{\partial t} = \alpha\left(6 |\psi|^2 \frac{\partial \psi}{\partial x} \right)

after evolving the equation for the first three terms. In other words, I need to call DifferentialEquations.solve repeatedly inside an outer time loop, each time evolving the solution by a single step. Unfortunately, this is very slow.
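Why the first three terms are easy can be made explicit. This is a sketch assuming Lie (first-order) splitting on a periodic domain with the Fourier convention ψ ∝ e^{ikx}. The two linear terms reduce to an exact phase factor per Fourier mode, the Kerr term is an exact pointwise phase rotation (it leaves |ψ| unchanged), and only the remaining term needs a numerical ODE solve.

Linear part, i ∂ψ/∂t = -(1/2) ∂²ψ/∂x² + iα ∂³ψ/∂x³, per Fourier mode k:

\hat{\psi}_k(t+\Delta t) = e^{-i\left(k^2/2 + \alpha k^3\right)\Delta t}\,\hat{\psi}_k(t)

Kerr part, i ∂ψ/∂t = -|ψ|²ψ, pointwise in x:

\psi(x, t+\Delta t) = e^{i|\psi(x,t)|^2\Delta t}\,\psi(x, t)

Remaining part, i ∂ψ/∂t = 6iα|ψ|² ∂ψ/∂x. Dividing by i gives exactly the equation above, which has no closed-form sub-step and is what the ODE solver handles.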

Here is a simplified example of how this works (I can provide a full MWE if needed, but it's on the longer side):

function solve!(...)
    # Prepare some stuff
    # ...

    # Differential Operators
    α = 0.1
    x_order = 2
    D = CenteredDifference(1, x_order, dx, Nx) 
    Q = PeriodicBC(Float64)

    function M!(du, u,p,t)
        du = 6*α*D*Q*u.*abs2.(u)
    end

    for i = 1:Nt-1
        @views ψ[:, i+1] .= T(ψ[:, i], dt, M!, ...)
    end

    return nothing
end #solve

function T(ψ, W, dt, F, F̃, M!)
    # Perform some operations in-place on ψ

    # Calling this every loop is slow and allocates a lot
    prob = ODEProblem(M!, ψ, (0.0, dt))
    sol = solve(prob, BS3(), save_everystep = false, dt = dt) 
    ψ = sol.u[end]

    return ψ
end

As you can see, the function T evolves the solution by a single time step: it performs some operations on ψ and then has to solve the ODE. This happens at every time step, so re-initializing the problem and solving it each time makes the whole process very time-consuming. I made the function M! in-place and it helped a bit, but I am stuck at this point.
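Even when du is written in place, an expression like 6*α*D*Q*u.*abs2.(u) likely still builds at least one intermediate array (the D*Q*u product) on every call. A minimal allocation-free sketch in plain Julia follows. It assumes a uniform periodic grid, a complex ψ, and the second-order centered stencil (ψ[j+1] - ψ[j-1])/(2dx), which should correspond to CenteredDifference(1, 2, dx, Nx) with a periodic boundary. The function name nonlinear_rhs! and the parameter layout p = (α, dx) are hypothetical, not the poster's code.

```julia
# Allocation-free right-hand side for ∂ψ/∂t = 6α|ψ|² ∂ψ/∂x on a periodic grid.
# Hypothetical parameter layout: p = (α, dx).
# ∂ψ/∂x uses a second-order centered difference with periodic wrap-around.
function nonlinear_rhs!(du, u, p, t)
    α, dx = p
    N = length(u)
    c = 6α / (2dx)
    @inbounds for j in 1:N
        jm = j == 1 ? N : j - 1      # periodic left neighbour
        jp = j == N ? 1 : j + 1      # periodic right neighbour
        du[j] = c * abs2(u[j]) * (u[jp] - u[jm])   # writes into du, no new array
    end
    return nothing
end

# Usage: the plane wave ψ = exp(i x) has |ψ|² = 1 and ∂ψ/∂x = iψ,
# so du should be close to 6α * im .* u (error is O(dx^2)).
N  = 256
dx = 2π / N
x  = dx .* (0:N-1)
u  = exp.(im .* x)
du = similar(u)
α  = 0.1
nonlinear_rhs!(du, u, (α, dx), 0.0)
maximum(abs.(du .- 6α .* im .* u))   # small discretization error
```

Because the function has the (du, u, p, t) signature, it could in principle be passed to ODEProblem in place of M!, with p supplied as the problem's parameters.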

Use du .= 6*α*D*Q*u.*abs2.(u). With du = ..., you're creating a new vector and rebinding the local name instead of mutating the array the solver passed in.
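The difference is ordinary Julia semantics, independent of DifferentialEquations.jl: the solver hands M! its own preallocated du and reads that array afterwards, so a rebinding assignment leaves the solver's du untouched and the computed derivative is discarded. A tiny standalone illustration (the function names are made up for the example):

```julia
# `du = expr` inside a function only rebinds the local name `du`;
# the array the caller passed in is left untouched.
function rebinds!(du, u)
    du = 2 .* u          # allocates a new array; the caller never sees it
    return nothing
end

# `du .= expr` writes element-wise into the existing array.
function mutates!(du, u)
    du .= 2 .* u         # fused broadcast into du, no new result array
    return nothing
end

u  = [1.0, 2.0, 3.0]
d1 = zeros(3); rebinds!(d1, u)
d2 = zeros(3); mutates!(d2, u)
d1   # still [0.0, 0.0, 0.0]
d2   # [2.0, 4.0, 6.0]
```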

Don't create a new solve every dt: why not use a DiscreteCallback or the integrator interface here? You're recreating the cache every time, which is what's hurting.
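To make "recreating the cache" concrete, here is a hand-rolled sketch of the allocate-once, step-many-times pattern that the integrator interface provides. It is not OrdinaryDiffEq's implementation: it uses the third-order Bogacki–Shampine weights (the method behind BS3) with a fixed step and no error control, and the names RK3Cache, rk3_step! and decay! are hypothetical.

```julia
# Stage buffers allocated once and reused for every step.
struct RK3Cache{T}
    k1::T
    k2::T
    k3::T
    tmp::T
end
RK3Cache(u) = RK3Cache(similar(u), similar(u), similar(u), similar(u))

# Advance u in place by one fixed step dt of f!(du, u, p, t),
# using the third-order Bogacki–Shampine tableau. No allocations per step.
function rk3_step!(u, f!, p, t, dt, c::RK3Cache)
    f!(c.k1, u, p, t)
    @. c.tmp = u + (dt / 2) * c.k1
    f!(c.k2, c.tmp, p, t + dt / 2)
    @. c.tmp = u + (3dt / 4) * c.k2
    f!(c.k3, c.tmp, p, t + 3dt / 4)
    @. u += dt * (2 / 9 * c.k1 + 1 / 3 * c.k2 + 4 / 9 * c.k3)
    return u
end

# Usage with a toy ODE du/dt = -p*u, integrated from t = 0 to t = 1.
decay!(du, u, p, t) = (du .= -p .* u; nothing)

u      = [1.0, 2.0]
cache  = RK3Cache(u)          # allocated once, outside the time loop
dt, Nt = 0.01, 101
for i in 1:Nt-1
    # ... other split-step sub-steps would modify u in place here ...
    rk3_step!(u, decay!, 1.0, (i - 1) * dt, dt, cache)
end
u   # ≈ [exp(-1), 2exp(-1)]
```

Building an ODEProblem and calling solve inside the loop is roughly like constructing a fresh cache (plus solution storage and setup) on every iteration; hoisting that construction out of the loop is the part that pays off.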

Integrator interface: https://diffeq.sciml.ai/stable/basics/integrator/#Initialization-and-Stepping
Callbacks: https://diffeq.sciml.ai/stable/features/callback_functions/

This is a quick answer because I'm busy, but hopefully it points in the right direction and can be cleared up in a follow-up. I just realized that the integrator interface documentation would be much better with a quick example; the test file test/integrators/iterator_tests.jl in the SciML/OrdinaryDiffEq.jl repository on GitHub could be used to write one.

Whoops, not the first time I've made this mistake. Thank you!

That definitely got me in the right direction, thank you Chris!

I ended up with:

using DifferentialEquations, DiffEqOperators

function solve!(...)
    # Prepare some stuff
    # ...

    # Differential Operators
    α = 0.1
    x_order = 2
    D = CenteredDifference(1, x_order, dx, Nx) 
    Q = PeriodicBC(Float64)
    function M!(du, u,p,t)
        du .= 6*α*D*Q*u.*abs2.(u)
    end
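    # tspan covers the whole run so the integrator can be stepped Nt-1 times
    prob = ODEProblem(M!, ψ[:, 1], (0.0, (Nt - 1) * dt))
    integrator = init(prob, BS3(); dt = dt, save_everystep = false)

    for i = 1:Nt-1
        @views ψ[:, i+1] .= T(ψ[:, i], dt, integrator, ...)
    end

    return nothing
end #solve

function T(ψ, dt, integrator, ...)
    # Perform some operations in-place on ψ

    # Now this is much quicker: the integrator and its cache are reused
    set_u!(integrator, ψ)
    # Advance exactly dt; BS3 is adaptive, so a bare step!(integrator)
    # may take a step of a different size
    step!(integrator, dt, true)
    ψ = integrator.u

    return ψ
end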

This improved performance by more than a factor of four! I will look into callbacks to replace set_u!, and keep trying to optimize this further. Thanks again.