Solve ODE with many different initial conditions

Hello friends, I have an ODE to solve with many different initial conditions varying in a nested loop.

In my old code I simply defined the ODEProblem inside the inner loop, which is the simplest approach but slow; I suppose it causes a ton of unnecessary memory allocations.

I tried to speed it up by using an EnsembleProblem from the Parallel Ensemble Simulations interface, but surprisingly it became slower.
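
For context, my ensemble attempt followed roughly this pattern (a sketch of the EnsembleProblem interface, not my exact code; traj, N, and T are defined in the code further below, and u0s stands in for the set of initial conditions):

using DifferentialEquations, StaticArrays

u0s = [@SVector(rand(4)) for _ in 1:1000]               # placeholder initial conditions
prob = ODEProblem(traj, u0s[1], (0.0, 100.0))           # placeholder time span
prob_func(prob, i, repeat) = remake(prob, u0 = u0s[i])  # pick the i-th initial condition
eprob = EnsembleProblem(prob; prob_func = prob_func)
sim = solve(eprob, Tsit5(), EnsembleThreads(); trajectories = length(u0s),
            reltol = 1e-6, save_everystep = false, save_start = false)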

Then I tried defining the ODEProblem in the outer loop, which is multi-threaded, and using the remake function to update the initial conditions in the inner loop, and that does speed things up substantially.

Afterwards I tried to speed it up further by defining the ODEProblem outside the loops, but surprisingly that only gives a very minor additional speedup.

The following is the original (simplest) code:

using DifferentialEquations
using BenchmarkTools
using StaticArrays
using Base.Threads
using PyPlot

# define system parameters
const Ip = 0.5
const N  = 2
const w  = 0.057
const T  = 2π/w
const F0 = 0.075

# define laser pulse
Fx(t) = F0 * cos(w*t/(2N))^2 * cos(w*t) * (abs(t)<N*T/2)
Fy(t) = F0 * cos(w*t/(2N))^2 * sin(w*t) * (abs(t)<N*T/2)

# define rate
W_adk(F,kd) = exp(-(2.0*(kd^2+2Ip)^1.5)/3F)

# define electron trajectory
const a = 1.0
function traj(u,p,t)
    r3i = (u[1]^2+u[2]^2+a)^(-1.5)
    du1 = u[3]
    du2 = u[4]
    du3 = -u[1]*r3i-Fx(t)
    du4 = -u[2]*r3i-Fy(t)
    @SVector [du1,du2,du3,du4]
end

# define result box
const Px_min = -2.0
const Px_max =  2.0
const Px_num =  200
const Px  = LinRange(Px_min,Px_max,Px_num)
const dPx = (Px_max-Px_min)/(Px_num-1)

const Py_min = -2.0
const Py_max =  2.0
const Py_num =  200
const Py  = LinRange(Py_min,Py_max,Py_num)
const dPy = (Py_max-Py_min)/(Py_num-1)

# define simulation parameters
const Nt = 500
const Nkd = 100
const kd_max = 2.0

# classical trajectory simulation
const Prob_nthreads = zeros(Float64,Px_num,Py_num,nthreads())
@benchmark @threads for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
    Fxtr = Fx(tr)
    Fytr = Fy(tr)
    Ftr = hypot(Fxtr,Fytr)
    phi = atan(-Fytr,-Fxtr)
    r0 = Ip/Ftr
    x0 = r0*cos(phi)
    y0 = r0*sin(phi)
    tspan = (tr,N*T/2)
    for kd in LinRange(-kd_max,kd_max,Nkd)
        kx0 = kd*cos(phi+0.5π)
        ky0 = kd*sin(phi+0.5π)
        rate = W_adk(Ftr,kd)

        u0 = @SVector [x0,y0,kx0,ky0]
        Traj = ODEProblem(traj,u0,tspan)
        sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
        (x,y,px,py) = sol.u[end]

        E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
        E_inf >= zero(E_inf) || continue
        pxidx = Int64(round((px-Px[1])/dPx))+1
        pyidx = Int64(round((py-Py[1])/dPy))+1
        checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
        @inbounds Prob_nthreads[pxidx,pyidx,threadid()] += rate
    end
end
const Prob = reshape(sum(Prob_nthreads,dims=3),size(Prob_nthreads)[1:2])

# plotting
fig,ax = plt.subplots()
ax.imshow(Prob',extent=(Px[1],Px[end],Py[1],Py[end]),origin="lower")
ax.set_xlim(Px[1],Px[end])
ax.set_ylim(Py[1],Py[end])
ax.set_xlabel("\$p_x\$ (a.u.)")
ax.set_ylabel("\$p_y\$ (a.u.)")
plt.show()

Defining the ODEProblem in the outer loop speeds things up a lot:

@benchmark @threads for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
    Fxtr = Fx(tr)
    Fytr = Fy(tr)
    Ftr = hypot(Fxtr,Fytr)
    phi = atan(-Fytr,-Fxtr)
    r0 = Ip/Ftr
    x0 = r0*cos(phi)
    y0 = r0*sin(phi)
    Traj = ODEProblem(traj,zeros(Float64,4),(tr,N*T/2))
    for kd in LinRange(-kd_max,kd_max,Nkd)
        kx0 = kd*cos(phi+0.5π)
        ky0 = kd*sin(phi+0.5π)
        rate = W_adk(Ftr,kd)

        u0 = @SVector [x0,y0,kx0,ky0]
        Traj = remake(Traj;u0=u0)
        sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
        (x,y,px,py) = sol.u[end]

        E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
        E_inf >= zero(E_inf) || continue
        pxidx = Int64(round((px-Px[1])/dPx))+1
        pyidx = Int64(round((py-Py[1])/dPy))+1
        checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
        @inbounds Prob_nthreads[pxidx,pyidx,threadid()] += rate
    end
end

But if I define the ODEProblem outside the loops, I barely see any further improvement in performance:

const Trajs = Vector{ODEProblem}(undef,nthreads())
for i=1:nthreads()
    Trajs[i] = ODEProblem(traj,zeros(Float64,4),(-N*T/2,N*T/2))
end
@benchmark @threads for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
    Fxtr = Fx(tr)
    Fytr = Fy(tr)
    Ftr = hypot(Fxtr,Fytr)
    phi = atan(-Fytr,-Fxtr)
    r0 = Ip/Ftr
    x0 = r0*cos(phi)
    y0 = r0*sin(phi)
    tspan = (tr,N*T/2)
    for kd in LinRange(-kd_max,kd_max,Nkd)
        kx0 = kd*cos(phi+0.5π)
        ky0 = kd*sin(phi+0.5π)
        rate = W_adk(Ftr,kd)

        u0 = @SVector [x0,y0,kx0,ky0]
        Traj = remake(Trajs[threadid()];u0=u0,tspan=tspan)
        sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
        (x,y,px,py) = sol.u[end]

        E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
        E_inf >= zero(E_inf) || continue
        pxidx = Int64(round((px-Px[1])/dPx))+1
        pyidx = Int64(round((py-Py[1])/dPy))+1
        checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
        @inbounds Prob_nthreads[pxidx,pyidx,threadid()] += rate
    end
end

Can anyone give me a hint about a better way to speed things up? Thanks!

Using the integrator form is going to be the fastest, though with static arrays the last form you did should get pretty close.
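
A minimal sketch of what the integrator interface looks like (reusing traj, N, and T from the code above; u0s here is just a placeholder for the set of initial conditions, and the exact keyword choices may need adjusting):

integ_prob = ODEProblem(traj, @SVector(zeros(4)), (-N*T/2, N*T/2))
integrator = init(integ_prob, Tsit5(); reltol=1e-6, save_everystep=false, save_start=false)
u0s = [@SVector(rand(4)) for _ in 1:10]    # placeholder initial conditions
for u0 in u0s
    reinit!(integrator, u0)                # reuse the integrator's caches; t0/tf keywords can change the time span
    solve!(integrator)                     # integrate to the end of tspan
    (x, y, px, py) = integrator.u          # final state
    # ... bin (px, py) as in the loops above ...
end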


The last code from the original post (slightly modified):

Prob_nthreads = let Prob_nthreads = zeros(Float64,Px_num,Py_num,nthreads())
Trajs = Vector{ODEProblem}(undef,nthreads())
for i=1:nthreads()
    Trajs[i] = ODEProblem(traj,zeros(Float64,4),(-N*T/2,N*T/2))
end
@show typeof(Trajs)
@time @threads for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
    Fxtr = Fx(tr)
    Fytr = Fy(tr)
    Ftr = hypot(Fxtr,Fytr)
    phi = atan(-Fytr,-Fxtr)
    r0 = Ip/Ftr
    x0 = r0*cos(phi)
    y0 = r0*sin(phi)
    tspan = (tr,N*T/2)
    for kd in LinRange(-kd_max,kd_max,Nkd)
        kx0 = kd*cos(phi+0.5π)
        ky0 = kd*sin(phi+0.5π)
        rate = W_adk(Ftr,kd)

        u0 = @SVector [x0,y0,kx0,ky0]
        Traj = remake(Trajs[threadid()];u0=u0,tspan=tspan)
        sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
        (x,y,px,py) = sol.u[end]
        
        E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
        E_inf >= zero(E_inf) || continue
        pxidx = Int64(round((px-Px[1])/dPx))+1
        pyidx = Int64(round((py-Py[1])/dPy))+1
        checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
        @inbounds Prob_nthreads[pxidx,pyidx,threadid()] += rate
    end
end
Prob_nthreads
end

Output:

typeof(Trajs) = Vector{ODEProblem}

0.261037 seconds (3.08 M allocations: 256.142 MiB, 0.00% compilation time)

Slightly improved version:

Prob_nthreads = let Prob_nthreads = zeros(Float64,Px_num,Py_num,nthreads())
Trajs = [ODEProblem(traj,@SVector(zeros(Float64,4)),(-N*T/2,N*T/2)) for _ in 1:nthreads()]
@show typeof(Trajs)
@time @threads for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
    Fxtr = Fx(tr)
    Fytr = Fy(tr)
    Ftr = hypot(Fxtr,Fytr)
    phi = atan(-Fytr,-Fxtr)
    r0 = Ip/Ftr
    x0 = r0*cos(phi)
    y0 = r0*sin(phi)
    tspan = (tr,N*T/2)
    for kd in LinRange(-kd_max,kd_max,Nkd)
        kx0 = kd*cos(phi+0.5π)
        ky0 = kd*sin(phi+0.5π)
        rate = W_adk(Ftr,kd)

        u0 = @SVector [x0,y0,kx0,ky0]
        Traj = remake(Trajs[threadid()];u0=u0,tspan=tspan)
        sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
        (x,y,px,py) = sol.u[end]

        E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
        E_inf >= zero(E_inf) || continue
        pxidx = Int64(round((px-Px[1])/dPx))+1
        pyidx = Int64(round((py-Py[1])/dPy))+1
        checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
        @inbounds Prob_nthreads[pxidx,pyidx,threadid()] += rate
    end
end
Prob_nthreads
end

Output:

typeof(Trajs) = Vector{ODEProblem{SVector{4, Float64}, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, ODEFunction{false, typeof(traj), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}

0.200081 seconds (1.21 M allocations: 158.318 MiB, 0.00% compilation time)

Comparing the two timings:

0.261037 seconds (3.08 M allocations: 256.142 MiB, 0.00% compilation time)
0.200081 seconds (1.21 M allocations: 158.318 MiB, 0.00% compilation time)

The number of allocations (and the run time) has been reduced.

The types of Trajs in the two versions above are different: in the former, Trajs is a Vector whose element type is the abstract type ODEProblem, while in the latter it is a Vector of a fully concrete (if verbose) problem type. If possible, it is better to avoid arrays whose element type is abstract, since every access to such an array is type-unstable.

Furthermore, the latter uses @SVector already in the definition of Trajs, so that the u0 stored in each problem is an SVector and the type of the remade Traj is identical to that of Trajs[threadid()].
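
To make the difference concrete, here is a minimal illustration (a toy example; only the element types matter):

Trajs_abstract = Vector{ODEProblem}(undef, nthreads())   # eltype is the abstract ODEProblem
Trajs_concrete = [ODEProblem(traj, @SVector(zeros(4)), (0.0, 1.0)) for _ in 1:nthreads()]
isconcretetype(eltype(Trajs_abstract))   # false: every access to Trajs_abstract[i] is type-unstable
isconcretetype(eltype(Trajs_concrete))   # true: the compiler knows the exact problem type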

Postscript: Jupyter notebook

In the code above I have used a let block to avoid const. However, because the code is not wrapped in a function, many useful tools such as @code_warntype become difficult to use.
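
As a small illustration (a toy function, not part of the code above), once the work is inside a function of explicit arguments, @code_warntype becomes straightforward to use:

using StaticArrays
using InteractiveUtils   # for @code_warntype (loaded by default in the REPL)

energy(u, a) = 0.5*(u[3]^2 + u[4]^2) - 1/(u[1]^2 + u[2]^2 + a)   # toy helper
@code_warntype energy(@SVector([0.1, 0.2, 0.3, 0.4]), 1.0)       # everything inferred as Float64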

After testing the code above, as well as other versions I wrote during the trial-and-error process, I am once again convinced that the use of const should be avoided whenever possible.

If we put the parameters that describe the problem in one variable and the parameters that describe how to solve it in another, and always pass them as arguments to the functions that solve the problem, we can write code that is readable, efficient, and easy to test, without using const.


To verify that following the rules below is not much trouble, and that doing so makes the code more readable and easier to handle, I have written the code in that style.

Rules:

  • Don’t use const.
  • Write functions.
  • Put a lot of parameters into a few variables and pass them to functions as arguments every time.

Jupyter notebook: https://github.com/genkuroki/public/blob/main/0018/Solve%20ODE%20with%20many%20different%20initial%20conditions%20Part%202.ipynb

using DifferentialEquations
using StaticArrays
using Base.Threads
using PyPlot
using Parameters

# define system parameters
Ip = 0.5
N  = 2
w  = 0.057
T  = 2π/w
F0 = 0.075
a = 1.0

# combine the system parameters into a single variable
p = (; Ip, N, w, T, F0, a,)

# define laser pulse
function Fx(p, t)
    @unpack F0, N, w, T = p
    F0 * cos(w*t/(2N))^2 * cos(w*t) * (abs(t) < N*T/2)
end
function Fy(p, t)
    @unpack F0, N, w, T = p
    F0 * cos(w*t/(2N))^2 * sin(w*t) * (abs(t) < N*T/2)
end

# define electron trajectory
function traj(u, p, t)
    @unpack a = p
    r3i = (u[1]^2 + u[2]^2 + a)^(-1.5)
    du1 = u[3]
    du2 = u[4]
    du3 = -u[1]*r3i - Fx(p, t)
    du4 = -u[2]*r3i - Fy(p, t)
    SVector(du1, du2, du3, du4)
end

# define rate
function W_adk(p, F, kd)
    @unpack Ip = p
    exp(-(2.0*(kd^2+2Ip)^1.5)/3F)
end

# define result box
Px  = range(-2, 2, length = 200)
Py  = range(-2, 2, length = 200)
    
# define simulation parameters
simparam = (
    Nt = 500,
    Nkd = 100,
    kd_max = 2.0,
)

function calcprob(p, Px, Py, simparam)
    @unpack Ip, N, w, T, F0, a = p
    Px_num, dPx = length(Px), step(Px)
    Py_num, dPy = length(Py), step(Py)
    @unpack Nt, Nkd, kd_max = simparam
    
    Prob_nthreads = zeros(Px_num, Py_num, nthreads())
    Trajs = [ODEProblem(traj, @SVector(zeros(4)), (-N*T/2, N*T/2), p) for _ in 1:nthreads()]
    
    @threads for tr in range(-0.9N*T/2, 0.9N*T/2, length = Nt)
        tid = threadid()
        Fxtr = Fx(p, tr)
        Fytr = Fy(p, tr)
        Ftr = hypot(Fxtr, Fytr)
        phi = atan(-Fytr, -Fxtr)
        r0 = Ip/Ftr
        x0 = r0*cos(phi)
        y0 = r0*sin(phi)
        tspan = (tr, N*T/2)
        
        for kd in range(-kd_max, kd_max, length = Nkd)
            kx0 = kd*cos(phi + 0.5π)
            ky0 = kd*sin(phi + 0.5π)
            rate = W_adk(p, Ftr, kd)

            u0 = SVector(x0, y0, kx0, ky0)
            Traj = remake(Trajs[tid]; u0=u0, tspan=tspan)
            sol = solve(Traj, Tsit5(), reltol=1e-6, save_everystep=false, save_start=false)
            (x, y, px, py) = sol.u[end]

            E_inf = 0.5*(px^2 + py^2) - 1/(x^2 + y^2 + a)
            E_inf >= zero(E_inf) || continue
            pxidx = round(Int, (px - Px[1])/dPx) + 1
            pyidx = round(Int, (py - Py[1])/dPy) + 1
            checkbounds(Bool, Prob_nthreads, pxidx, pyidx, tid) || continue
            @inbounds Prob_nthreads[pxidx, pyidx, tid] += rate
        end
    end
    Prob = reshape(sum(Prob_nthreads, dims=3), size(Prob_nthreads)[1:2])
end

# plotting
function plotprob(Px, Py, Prob)
    fig, ax = plt.subplots()
    ax.imshow(Prob', extent=(Px[1], Px[end], Py[1], Py[end]), origin="lower")
    ax.set_xlim(Px[1], Px[end])
    ax.set_ylim(Py[1], Py[end])
    ax.set_xlabel("\$p_x\$ (a.u.)")
    ax.set_ylabel("\$p_y\$ (a.u.)")
    #plt.show()
end

@time Prob = calcprob(p, Px, Py, simparam)
@time Prob = calcprob(p, Px, Py, simparam)
@time Prob = calcprob(p, Px, Py, simparam)
plotprob(Px, Py, Prob)

Result:

  4.674843 seconds (14.10 M allocations: 946.903 MiB, 4.23% gc time, 96.75% compilation time)
  0.204111 seconds (1.00 M allocations: 155.970 MiB, 19.95% gc time)
  0.146866 seconds (1.00 M allocations: 155.970 MiB)


If we stop defining parameters with const, we can change their values and types freely and safely, which makes it easier to test the code and to experiment by trial and error.

If we always pass the necessary parameters to a function as arguments, the code becomes much easier to read, because a function's dependencies are limited to what appears in its own definition.

These points are not specific to Julia; they apply equally when writing code in C/C++ or Fortran.

In Julia, following these rules often also yields higher performance, as shown above.


@genkuroki Thanks a lot for your kind instructions; they have been a great help, and I appreciate your warmheartedness very much!