Performance when using DifferentialEquations.solve repeatedly inside a loop

I am using the split-step method to solve the following equation:

$$ i \frac{\partial \psi}{\partial t} = -\frac{1}{2} \frac{\partial^2 \psi}{\partial x^2} - |\psi|^2\psi + i \alpha\left(\frac{\partial^3 \psi}{\partial x^3} + 6 |\psi|^2 \frac{\partial \psi}{\partial x} \right) $$

The first three terms are straightforward to handle with the split-step method. For the fourth term, I basically need to solve this equation:

$$ \frac{\partial \psi}{\partial t} = \alpha\left(6 |\psi|^2 \frac{\partial \psi}{\partial x} \right) $$

after evolving the equation for the first three terms. That is, I need to call DifferentialEquations.solve repeatedly inside an outer time loop, each time evolving the solution by a single step. Unfortunately, this is very slow.
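For reference, here is one way to write a single first-order (Lie) splitting step of size $\Delta t$. It assumes that the two linear derivative terms are advanced exactly in Fourier space with the convention $\psi = \sum_k \hat\psi_k e^{ikx}$. That is an assumption about how the first three terms are handled, not something taken from the full code.

$$
\begin{aligned}
\text{(a) linear terms:} \quad & \hat\psi_k \leftarrow e^{-i\left(k^2/2 + \alpha k^3\right)\Delta t}\,\hat\psi_k \\
\text{(b) cubic term:} \quad & \psi \leftarrow \psi\, e^{i|\psi|^2 \Delta t} \\
\text{(c) fourth term:} \quad & \psi \leftarrow \psi(\Delta t), \text{ where } \frac{\partial \psi}{\partial t} = 6\alpha |\psi|^2 \frac{\partial \psi}{\partial x}
\end{aligned}
$$

Step (b) is exact because $|\psi|$ is constant under $i\,\partial_t\psi = -|\psi|^2\psi$. Step (c) has no closed-form update of this kind, which is why an ODE solve is needed at every time step.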

This is a simplified example of how it works (I can provide a full MWE if needed, but it’s on the longer side):

```julia
using DifferentialEquations  # ODEProblem, solve, BS3
using DiffEqOperators        # CenteredDifference, PeriodicBC

function solve!(...)
    # Prepare some stuff
    # ...

    # Differential Operators
    α = 0.1
    x_order = 2
    D = CenteredDifference(1, x_order, dx, Nx)
    Q = PeriodicBC(Float64)

    function M!(du, u,p,t)
        du = 6*α*D*Q*u.*abs2.(u)
    end

    for i = 1:Nt-1
        @views ψ[:, i+1] .= T(ψ[:, i], dt, M!, ...)
    end

    return nothing
end #solve

function T(ψ, W, dt, F, F̃, M!)
    # Perform some operations in-place on ψ

    # Calling this every loop is slow and allocates a lot
    prob = ODEProblem(M!, ψ, (0.0, dt))
    sol = solve(prob, BS3(), save_everystep = false, dt = dt)
    ψ = sol.u[end]

    return ψ
end
```

So, as you can see, the function T advances time by a single step: it performs some operations on ψ and then has to solve the ODE. This happens at every time step, so re-initializing the problem and then solving it makes the whole process very time-consuming. I made the function M! in place and it helped a bit, but I am stuck at this point.

Use du .= 6*α*D*Q*u.*abs2.(u). With du = ..., you’re creating a new vector and rebinding the local variable du, not mutating the array the solver passed in.
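The difference between = and .= is easy to see in isolation. Here is a minimal sketch with two hypothetical right-hand-side functions (rhs_rebind! and rhs_mutate! are illustrative names, not from the thread). The first only rebinds its local argument, so the caller's buffer, which is the array the solver actually reads, never changes.

```julia
# Two hypothetical in-place RHS functions that differ only in `=` vs `.=`.
function rhs_rebind!(du, u, p, t)
    du = 2 .* u      # allocates a new array and rebinds the local name `du`
    return nothing   # the caller's array is left untouched
end

function rhs_mutate!(du, u, p, t)
    du .= 2 .* u     # fused broadcast that writes into the existing array
    return nothing
end

u  = [1.0, 2.0, 3.0]
du = zeros(3)

rhs_rebind!(du, u, nothing, 0.0)
println(du)   # [0.0, 0.0, 0.0]

rhs_mutate!(du, u, nothing, 0.0)
println(du)   # [2.0, 4.0, 6.0]
```

Because the solver only looks at the du array it passed in, the first form leaves du holding whatever it contained before the call.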

Don’t create a new solve every dt: why not use a DiscreteCallback or the integrator interface here? You’re recreating the solver’s cache every time, and that’s what’s hurting.

Integrator interface: https://diffeq.sciml.ai/stable/basics/integrator/#Initialization-and-Stepping
Callbacks: https://diffeq.sciml.ai/stable/features/callback_functions/

This is a quick answer because I’m busy, but hopefully it points you in the right direction and can be cleared up in a follow-up. I just realized that the integrator interface documentation would be much better with a quick example. https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/test/integrators/iterator_tests.jl could be used to write such an example.
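To see why reusing the cache matters, here is a dependency-free sketch of the idea behind init/step!. The stage buffers are allocated once and reused on every step, and a reset of the state between steps stands in for set_u!. FixedBS3, step_bs3! and reset_u! are hypothetical names, not part of OrdinaryDiffEq. The stepper uses a fixed step and only the third-order Bogacki-Shampine weights, with no error control, so it is not what BS3() actually does internally.

```julia
# Hypothetical minimal integrator: stage buffers are allocated once in the
# constructor and reused by every step. Fixed step size, no error control,
# third-order Bogacki-Shampine weights only (not OrdinaryDiffEq's BS3()).
mutable struct FixedBS3{F,V,P}
    f!::F        # in-place RHS with signature f!(du, u, p, t)
    u::V         # current state
    t::Float64   # current time
    dt::Float64  # fixed step size
    p::P         # parameters passed to f!
    k1::V        # preallocated stage storage, reused on every step
    k2::V
    k3::V
    tmp::V
end

FixedBS3(f!, u0, dt, p) = FixedBS3(f!, copy(u0), 0.0, Float64(dt), p,
                                   similar(u0), similar(u0), similar(u0), similar(u0))

# Overwrite the state in place (plays the role of set_u! in this sketch).
reset_u!(it::FixedBS3, u) = (it.u .= u; it)

# Advance by it.dt, writing only into preallocated arrays.
function step_bs3!(it::FixedBS3)
    u, t, dt = it.u, it.t, it.dt
    k1, k2, k3, tmp = it.k1, it.k2, it.k3, it.tmp
    it.f!(k1, u, it.p, t)
    @. tmp = u + dt / 2 * k1
    it.f!(k2, tmp, it.p, t + dt / 2)
    @. tmp = u + 3dt / 4 * k2
    it.f!(k3, tmp, it.p, t + 3dt / 4)
    @. u += dt * (2 / 9 * k1 + 1 / 3 * k2 + 4 / 9 * k3)
    it.t = t + dt
    return it
end

# Usage: build once, then reset the state and step inside the outer loop.
decay!(du, u, p, t) = (du .= -p .* u; nothing)
it = FixedBS3(decay!, [1.0], 0.01, 1.0)
ψ = [1.0]
for n in 1:100
    # ...the other split-step substeps would modify ψ here...
    reset_u!(it, ψ)
    step_bs3!(it)
    ψ .= it.u
end
println(ψ[1])   # close to exp(-1) ≈ 0.3679
```

In OrdinaryDiffEq the corresponding pattern is to call init once and then set_u! and step! inside the loop. The real integrator also handles adaptivity and interpolation, which this sketch leaves out.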


Whoops, not the first time I’ve made this mistake, thank you!

That definitely got me in the right direction, thank you Chris!

I ended up with:

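```julia
using DifferentialEquations  # ODEProblem, init, set_u!, step!, BS3
using DiffEqOperators        # CenteredDifference, PeriodicBC

function solve!(...)
    # Prepare some stuff
    # ...

    # Differential Operators
    α = 0.1
    x_order = 2
    D = CenteredDifference(1, x_order, dx, Nx)
    Q = PeriodicBC(Float64)
    function M!(du, u,p,t)
        du .= 6*α*D*Q*u.*abs2.(u)
    end
    # Build the problem and integrator once; the time span covers the whole outer loop
    prob = ODEProblem(M!, ψ[:, 1], (0.0, Nt*dt))
    integrator = init(prob, BS3(); dt=dt, save_everystep=false)

    for i = 1:Nt-1
        # Pass the integrator (not M!) so that T can reuse it
        @views ψ[:, i+1] .= T(ψ[:, i], dt, integrator, ...)
    end

    return nothing
end #solve

function T(ψ, dt, integrator, ...)
    # Perform some operations in-place on ψ

    # Now this is much quicker: reuse the integrator instead of re-initializing
    set_u!(integrator, ψ)
    step!(integrator, dt, true)  # advance exactly dt; step!(integrator) takes one adaptive step
    ψ = integrator.u

    return ψ
end
```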

This improved the performance by more than a factor of four! I will take a look at callbacks to replace set_u!, and keep trying to optimize this further. Thanks again.
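On further optimization: even with .=, the expression 6*α*D*Q*u.*abs2.(u) likely still builds intermediate operators and temporary arrays on every call. Below is a sketch of an allocation-free alternative. It assumes a uniform periodic grid with spacing dx and writes the second-order centered difference out by hand. This should match what CenteredDifference(1, 2, dx, Nx) combined with PeriodicBC computes. However, hirota_rhs! is a hypothetical helper, not code from DiffEqOperators, and it takes its parameters as p = (α, dx) instead of capturing them in a closure.

```julia
# Hypothetical allocation-free RHS for ∂ψ/∂t = 6α|ψ|² ∂ψ/∂x on a uniform
# periodic grid, using the second-order centered difference
# (ψ[j+1] - ψ[j-1]) / (2dx) with periodic wrap-around.
# Parameters are passed as p = (α, dx).
function hirota_rhs!(dψ, ψ, p, t)
    α, dx = p
    N = length(ψ)
    @inbounds for j in 1:N
        jp = j == N ? 1 : j + 1    # right neighbour, wrapping to 1
        jm = j == 1 ? N : j - 1    # left neighbour, wrapping to N
        dψdx = (ψ[jp] - ψ[jm]) / (2dx)
        dψ[j] = 6α * abs2(ψ[j]) * dψdx
    end
    return nothing
end

# Check on a plane wave ψ = exp(ikx): |ψ|² = 1 and the centered difference
# of exp(ikx) on the grid is exactly i·sin(k·dx)/dx · ψ.
N, k, α = 64, 3, 0.1
dx = 2π / N
x  = [(j - 1) * dx for j in 1:N]
ψ  = exp.(im .* k .* x)
dψ = similar(ψ)
hirota_rhs!(dψ, ψ, (α, dx), 0.0)
expected = 6α .* (im * sin(k * dx) / dx) .* ψ
println(maximum(abs.(dψ .- expected)) < 1e-10)   # true
```

It could then be passed to ODEProblem(hirota_rhs!, ψ0, tspan, (α, dx)), where ψ0 and tspan stand for the initial state and time span, and reused with the same init/set_u!/step! pattern as in the code above.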
