# Solve ODE with many different initial conditions

Hello friends, I have an ODE to solve with many different initial conditions varying in a nested loop.

In my old code, I simply defined the `ODEProblem` in the inner loop, which is the simplest approach but slow. I suspect there are a ton of unnecessary memory allocations.

I tried to speed it up with `EnsembleProblem` (from the Parallel Ensemble Simulations interface), but surprisingly it became slower.
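For reference, the `EnsembleProblem` pattern looks roughly like the minimal sketch below. It is not the code from this thread: it assumes a hypothetical toy right-hand side `f` (free motion under a constant force `-p` along x) in place of `traj`, and a hypothetical list of initial momenta `kxs`. `prob_func` builds each trajectory's problem with `remake`, and `output_func` keeps only the final state.

```julia
using DifferentialEquations, StaticArrays

# Hypothetical toy system: u = (x, y, px, py), constant force -p along x
f(u, p, t) = SVector(u[3], u[4], -p, zero(p))
base = ODEProblem(f, SVector(0.0, 0.0, 0.0, 0.0), (0.0, 1.0), 1.0)

# Hypothetical set of initial momenta, one trajectory each
kxs = range(-1.0, 1.0, length = 5)

# Build the i-th problem from the base problem by swapping in a new u0
prob_func(prob, i, repeat) = remake(prob; u0 = SVector(0.0, 0.0, kxs[i], 0.0))
# Keep only the final state; `false` means "do not rerun this trajectory"
output_func(sol, i) = (sol.u[end], false)

ens = EnsembleProblem(base; prob_func = prob_func, output_func = output_func)
sim = solve(ens, Tsit5(), EnsembleThreads(); trajectories = length(kxs),
            reltol = 1e-6, save_everystep = false, save_start = false)
finals = sim.u   # one final state per initial condition
```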

Then I defined the `ODEProblem` in the outer loop, which is multi-threaded, and used the `remake` function to update the initial conditions in the inner loop. This does speed things up substantially.

Afterwards, I tried to speed it up further by defining the `ODEProblem` outside the loops altogether, but surprisingly this gave only a very minor additional speedup.

The following is the simplest original code:

```julia
using DifferentialEquations
using BenchmarkTools
using StaticArrays
using PyPlot
using Base.Threads

# define system parameters
const Ip = 0.5
const N  = 2
const w  = 0.057
const T  = 2π/w
const F0 = 0.075

# define laser pulse
Fx(t) = F0 * cos(w*t/(2N))^2 * cos(w*t) * (abs(t)<N*T/2)
Fy(t) = F0 * cos(w*t/(2N))^2 * sin(w*t) * (abs(t)<N*T/2)

# define rate (field strength F, initial transverse momentum kd)
rate(F,kd) = exp(-(2.0*(kd^2+2Ip)^1.5)/3F)

# define electron trajectory
const a = 1.0
function traj(u,p,t)
    r3i = (u[1]^2+u[2]^2+a)^(-1.5)
    du1 = u[3]
    du2 = u[4]
    du3 = -u[1]*r3i-Fx(t)
    du4 = -u[2]*r3i-Fy(t)
    @SVector [du1,du2,du3,du4]
end

# define result box
const Px_min = -2.0
const Px_max =  2.0
const Px_num =  200
const Px  = LinRange(Px_min,Px_max,Px_num)
const dPx = (Px_max-Px_min)/(Px_num-1)

const Py_min = -2.0
const Py_max =  2.0
const Py_num =  200
const Py  = LinRange(Py_min,Py_max,Py_num)
const dPy = (Py_max-Py_min)/(Py_num-1)

# define simulation parameters
const Nt = 500
const Nkd = 100
const kd_max = 2.0

# one histogram per thread, summed after the loop
const Prob_nthreads = zeros(Float64,Px_num,Py_num,nthreads())

# classical trajectory simulation
# (:static keeps threadid() fixed within each iteration)
@benchmark @threads :static for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
    Fxtr = Fx(tr)
    Fytr = Fy(tr)
    Ftr = hypot(Fxtr,Fytr)
    phi = atan(-Fytr,-Fxtr)
    r0 = Ip/Ftr
    x0 = r0*cos(phi)
    y0 = r0*sin(phi)
    tspan = (tr,N*T/2)
    for kd in LinRange(-kd_max,kd_max,Nkd)
        kx0 = kd*cos(phi+0.5π)
        ky0 = kd*sin(phi+0.5π)

        u0 = @SVector [x0,y0,kx0,ky0]
        Traj = ODEProblem(traj,u0,tspan)
        sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
        (x,y,px,py) = sol.u[end]

        # keep only electrons with non-negative final energy
        E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
        E_inf >= zero(E_inf) || continue
        pxidx = Int64(round((px-Px[1])/dPx))+1
        pyidx = Int64(round((py-Py[1])/dPy))+1
        checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
        Prob_nthreads[pxidx,pyidx,threadid()] += rate(Ftr,kd)
    end
end
Prob = dropdims(sum(Prob_nthreads,dims=3),dims=3)

# plotting
fig,ax = plt.subplots()
ax.imshow(Prob',extent=(Px[1],Px[end],Py[1],Py[end]),origin="lower")
ax.set_xlim(Px[1],Px[end])
ax.set_ylim(Py[1],Py[end])
ax.set_xlabel("\$p_x\$ (a.u.)")
ax.set_ylabel("\$p_y\$ (a.u.)")
plt.show()
```

Defining the `ODEProblem` in the outer loop speeds it up a lot:

```julia
@benchmark @threads for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
    Fxtr = Fx(tr)
    Fytr = Fy(tr)
    Ftr = hypot(Fxtr,Fytr)
    phi = atan(-Fytr,-Fxtr)
    r0 = Ip/Ftr
    x0 = r0*cos(phi)
    y0 = r0*sin(phi)
    Traj = ODEProblem(traj,zeros(Float64,4),(tr,N*T/2))
    for kd in LinRange(-kd_max,kd_max,Nkd)
        kx0 = kd*cos(phi+0.5π)
        ky0 = kd*sin(phi+0.5π)

        u0 = @SVector [x0,y0,kx0,ky0]
        Traj = remake(Traj;u0=u0)
        sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
        (x,y,px,py) = sol.u[end]

        E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
        E_inf >= zero(E_inf) || continue
        pxidx = Int64(round((px-Px[1])/dPx))+1
        pyidx = Int64(round((py-Py[1])/dPy))+1
    end
end
```

But defining the `ODEProblem` outside the loop gives barely any further improvement:

```julia
const Trajs = Vector{ODEProblem}(undef,nthreads())
for i in 1:nthreads()
    Trajs[i] = ODEProblem(traj,zeros(Float64,4),(-N*T/2,N*T/2))
end
# :static keeps threadid() fixed within each iteration
@benchmark @threads :static for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
    Fxtr = Fx(tr)
    Fytr = Fy(tr)
    Ftr = hypot(Fxtr,Fytr)
    phi = atan(-Fytr,-Fxtr)
    r0 = Ip/Ftr
    x0 = r0*cos(phi)
    y0 = r0*sin(phi)
    tspan = (tr,N*T/2)
    for kd in LinRange(-kd_max,kd_max,Nkd)
        kx0 = kd*cos(phi+0.5π)
        ky0 = kd*sin(phi+0.5π)

        u0 = @SVector [x0,y0,kx0,ky0]
        Traj = remake(Trajs[threadid()];u0=u0,tspan=tspan)
        sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
        (x,y,px,py) = sol.u[end]

        E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
        E_inf >= zero(E_inf) || continue
        pxidx = Int64(round((px-Px[1])/dPx))+1
        pyidx = Int64(round((py-Py[1])/dPy))+1
    end
end
```

Can anyone suggest a better way to speed this up? Thanks!

Using the integrator form is going to be the fastest, though with static arrays the last form you did should get pretty close.
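A minimal sketch of the integrator form, assuming the standard DifferentialEquations.jl integrator interface: build the integrator once with `init`, then for each initial condition call `reinit!` with a new `u0` and start time `t0`, and advance it with `solve!`. The toy system `f` and the `cases` list are hypothetical stand-ins for `traj` and the `(tr, kd)` loop.

```julia
using DifferentialEquations, StaticArrays

# Hypothetical toy system standing in for `traj`: constant force -p along x
f(u, p, t) = SVector(u[3], u[4], -p, zero(p))
prob = ODEProblem(f, SVector(0.0, 0.0, 0.0, 0.0), (0.0, 1.0), 1.0)

# Build the integrator (and its cache) once
integ = init(prob, Tsit5(); reltol = 1e-6, save_everystep = false, save_start = false)

# Hypothetical (start time, initial px) pairs
cases = [(0.0, 1.0), (0.5, -1.0)]
finals = SVector{4,Float64}[]
for (t0, kx) in cases
    # Reset state and time span on the existing integrator instead of building a new problem
    reinit!(integ, SVector(0.0, 0.0, kx, 0.0); t0 = t0, tf = 1.0)
    solve!(integ)              # integrate from t0 to tf
    push!(finals, integ.u)     # state at tf
end
```

In the threaded loop, a single `integ` would presumably become a vector with one integrator per thread, indexed by `threadid()` under `@threads :static`.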


Your original last version (slightly modified):

```julia
Prob_nthreads = let Prob_nthreads = zeros(Float64,Px_num,Py_num,nthreads())
    Trajs = Vector{ODEProblem}(undef,nthreads())
    for i in 1:nthreads()
        Trajs[i] = ODEProblem(traj,zeros(Float64,4),(-N*T/2,N*T/2))
    end
    @show typeof(Trajs)
    @time @threads :static for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
        Fxtr = Fx(tr)
        Fytr = Fy(tr)
        Ftr = hypot(Fxtr,Fytr)
        phi = atan(-Fytr,-Fxtr)
        r0 = Ip/Ftr
        x0 = r0*cos(phi)
        y0 = r0*sin(phi)
        tspan = (tr,N*T/2)
        for kd in LinRange(-kd_max,kd_max,Nkd)
            kx0 = kd*cos(phi+0.5π)
            ky0 = kd*sin(phi+0.5π)

            u0 = @SVector [x0,y0,kx0,ky0]
            Traj = remake(Trajs[threadid()];u0=u0,tspan=tspan)
            sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
            (x,y,px,py) = sol.u[end]

            E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
            E_inf >= zero(E_inf) || continue
            pxidx = Int64(round((px-Px[1])/dPx))+1
            pyidx = Int64(round((py-Py[1])/dPy))+1
            checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
            Prob_nthreads[pxidx,pyidx,threadid()] += rate(Ftr,kd)
        end
    end
    Prob_nthreads
end
```

Output:

```
typeof(Trajs) = Vector{ODEProblem}
  0.261037 seconds (3.08 M allocations: 256.142 MiB, 0.00% compilation time)
```

Slightly improved version:

```julia
Prob_nthreads = let Prob_nthreads = zeros(Float64,Px_num,Py_num,nthreads())
    Trajs = [ODEProblem(traj,@SVector(zeros(Float64,4)),(-N*T/2,N*T/2)) for _ in 1:nthreads()]
    @show typeof(Trajs)
    @time @threads :static for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
        Fxtr = Fx(tr)
        Fytr = Fy(tr)
        Ftr = hypot(Fxtr,Fytr)
        phi = atan(-Fytr,-Fxtr)
        r0 = Ip/Ftr
        x0 = r0*cos(phi)
        y0 = r0*sin(phi)
        tspan = (tr,N*T/2)
        for kd in LinRange(-kd_max,kd_max,Nkd)
            kx0 = kd*cos(phi+0.5π)
            ky0 = kd*sin(phi+0.5π)

            u0 = @SVector [x0,y0,kx0,ky0]
            Traj = remake(Trajs[threadid()];u0=u0,tspan=tspan)
            sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
            (x,y,px,py) = sol.u[end]

            E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
            E_inf >= zero(E_inf) || continue
            pxidx = Int64(round((px-Px[1])/dPx))+1
            pyidx = Int64(round((py-Py[1])/dPy))+1
            checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
            Prob_nthreads[pxidx,pyidx,threadid()] += rate(Ftr,kd)
        end
    end
    Prob_nthreads
end
```

Output:

```
typeof(Trajs) = Vector{ODEProblem{SVector{4, Float64}, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, ODEFunction{false, typeof(traj), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}
  0.200081 seconds (1.21 M allocations: 158.318 MiB, 0.00% compilation time)
```


The allocations have been reduced.

The two snippets above give `Trajs` different types. In the former, `Trajs` is a `Vector` whose element type is the abstract type `ODEProblem`; in the latter, the element type is the long concrete type shown in the output. Where possible, avoid arrays whose element type is abstract.

Furthermore, the latter uses `@SVector` even when constructing `Trajs`, so that `Traj` (the problem produced by `remake` with an `SVector` initial condition) has the same type as `Trajs[threadid()]`.
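The same effect can be reproduced without DifferentialEquations. In this sketch, `Complex` written without its type parameter plays the role of `ODEProblem` (both are `UnionAll`s, so neither is a concrete type), while the comprehension lets Julia infer a concrete element type, as the comprehension that builds `Trajs` does. The names `zs_declared` and `zs_inferred` are purely illustrative.

```julia
# Declared with a parameter-free type: the element type is abstract,
# so each element is stored boxed and its type is unknown to the compiler
zs_declared = Vector{Complex}(undef, 3)
for i in 1:3
    zs_declared[i] = complex(float(i), 0.0)
end

# Built with a comprehension: Julia infers the concrete element type
zs_inferred = [complex(float(i), 0.0) for i in 1:3]

@show eltype(zs_declared) isconcretetype(eltype(zs_declared))
@show eltype(zs_inferred) isconcretetype(eltype(zs_inferred))
```

Both vectors hold the same values; only the inferred one lets code that reads its elements compile to type-stable loops.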

Postscript: Jupyter notebook

In the code above, I used a `let` block to avoid `const`. Because the code is not inside a function, however, useful tools such as `@code_warntype` are hard to apply to it.

After testing the code above, as well as other versions I wrote during trial and error, I am again convinced that `const` should be avoided whenever possible.

If we put the parameters describing the problem in one variable, and the parameters describing how to solve the problem in another variable, and write code that always passes them as arguments to the functions used to solve the problem, we can write code that is readable, efficient, and easy to test, without using `const`.
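One way to see the performance side of this: `Base.return_types` reports what the compiler can infer, and `@code_warntype` would flag the same `Any` interactively. In this sketch, `scale`, `f_global`, and `f_arg` are hypothetical names.

```julia
scale = 2.0                     # non-const global: its type could change at any time
f_global(x) = scale * x         # reads the global
f_arg(x, scale) = scale * x     # receives the value as an argument

# What the compiler can infer for Float64 inputs
rt_global = Base.return_types(f_global, (Float64,))
rt_arg    = Base.return_types(f_arg, (Float64, Float64))
@show rt_global rt_arg
```

Making `scale` a `const` would also fix `f_global`, but passing it as an argument keeps the value easy to change between calls.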


To check that the rules below are not much trouble to follow, and that following them makes the code more readable and easier to handle, I have rewritten the code accordingly.

Rules:

• Don’t use `const`.
• Write functions.
• Put a lot of parameters into a few variables and pass them to functions as arguments every time.
```julia
using DifferentialEquations
using StaticArrays
using PyPlot
using Parameters
using Base.Threads
```
```julia
# define system parameters
Ip = 0.5
N  = 2
w  = 0.057
T  = 2π/w
F0 = 0.075
a = 1.0

# combine the system parameters into a single variable
p = (; Ip, N, w, T, F0, a,)
``````
```julia
# define laser pulse
function Fx(p, t)
    @unpack F0, N, w, T = p
    F0 * cos(w*t/(2N))^2 * cos(w*t) * (abs(t) < N*T/2)
end
function Fy(p, t)
    @unpack F0, N, w, T = p
    F0 * cos(w*t/(2N))^2 * sin(w*t) * (abs(t) < N*T/2)
end

# define electron trajectory
function traj(u, p, t)
    @unpack a = p
    r3i = (u[1]^2 + u[2]^2 + a)^(-1.5)
    du1 = u[3]
    du2 = u[4]
    du3 = -u[1]*r3i - Fx(p, t)
    du4 = -u[2]*r3i - Fy(p, t)
    SVector(du1, du2, du3, du4)
end

# define rate (field strength F, initial transverse momentum kd)
function rate(p, F, kd)
    @unpack Ip = p
    exp(-(2.0*(kd^2+2Ip)^1.5)/3F)
end
```
```julia
# define result box
Px  = range(-2, 2, length = 200)
Py  = range(-2, 2, length = 200)

# define simulation parameters
simparam = (
    Nt = 500,
    Nkd = 100,
    kd_max = 2.0,
)
```
```julia
function calcprob(p, Px, Py, simparam)
    @unpack Ip, N, w, T, F0, a = p
    Px_num, dPx = length(Px), step(Px)
    Py_num, dPy = length(Py), step(Py)
    @unpack Nt, Nkd, kd_max = simparam

    # one histogram per thread, summed at the end
    Prob_nthreads = zeros(Px_num, Py_num, nthreads())
    Trajs = [ODEProblem(traj, @SVector(zeros(4)), (-N*T/2, N*T/2), p) for _ in 1:nthreads()]

    # :static keeps threadid() fixed within each iteration
    @threads :static for tr in range(-0.9N*T/2, 0.9N*T/2, length = Nt)
        tid = threadid()
        Fxtr = Fx(p, tr)
        Fytr = Fy(p, tr)
        Ftr = hypot(Fxtr, Fytr)
        phi = atan(-Fytr, -Fxtr)
        r0 = Ip/Ftr
        x0 = r0*cos(phi)
        y0 = r0*sin(phi)
        tspan = (tr, N*T/2)

        for kd in range(-kd_max, kd_max, length = Nkd)
            kx0 = kd*cos(phi + 0.5π)
            ky0 = kd*sin(phi + 0.5π)

            u0 = SVector(x0, y0, kx0, ky0)
            Traj = remake(Trajs[tid]; u0=u0, tspan=tspan)
            sol = solve(Traj, Tsit5(), reltol=1e-6, save_everystep=false, save_start=false)
            (x, y, px, py) = sol.u[end]

            E_inf = 0.5*(px^2 + py^2) - 1/(x^2 + y^2 + a)
            E_inf >= zero(E_inf) || continue
            pxidx = round(Int, (px - Px[1])/dPx) + 1
            pyidx = round(Int, (py - Py[1])/dPy) + 1
            checkbounds(Bool, Prob_nthreads, pxidx, pyidx, tid) || continue
            @inbounds Prob_nthreads[pxidx, pyidx, tid] += rate(p, Ftr, kd)
        end
    end
    dropdims(sum(Prob_nthreads, dims=3), dims=3)
end
```
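The histogram step in `calcprob` maps a final momentum to the nearest point of a uniform grid and skips values that fall outside it. Isolated as a small function (the name `binindex` is hypothetical, not part of the code above), the same arithmetic is easy to test on its own:

```julia
# Index of the grid point nearest to x on a uniform range, or `nothing`
# if x falls outside the grid (same arithmetic as pxidx/pyidx in calcprob)
function binindex(grid::AbstractRange, x::Real)
    i = round(Int, (x - first(grid)) / step(grid)) + 1
    return checkindex(Bool, eachindex(grid), i) ? i : nothing
end

grid = range(-2, 2, length = 200)
binindex(grid, 0.5), binindex(grid, 3.0)
```

Note that `round(Int, x)` throws an `InexactError` for `NaN` or values outside the `Int` range, so a diverging trajectory would raise an error rather than being skipped.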
```julia
# plotting
function plotprob(Px, Py, Prob)
    fig, ax = plt.subplots()
    ax.imshow(Prob', extent=(Px[1], Px[end], Py[1], Py[end]), origin="lower")
    ax.set_xlim(Px[1], Px[end])
    ax.set_ylim(Py[1], Py[end])
    ax.set_xlabel("\$p_x\$ (a.u.)")
    ax.set_ylabel("\$p_y\$ (a.u.)")
    #plt.show()
end
```
```julia
@time Prob = calcprob(p, Px, Py, simparam)
@time Prob = calcprob(p, Px, Py, simparam)
@time Prob = calcprob(p, Px, Py, simparam)
plotprob(Px, Py, Prob)
``````

Result:

```
  4.674843 seconds (14.10 M allocations: 946.903 MiB, 4.23% gc time, 96.75% compilation time)
  0.204111 seconds (1.00 M allocations: 155.970 MiB, 19.95% gc time)
  0.146866 seconds (1.00 M allocations: 155.970 MiB)
```

If we stop defining parameters with `const`, we can change their values and types freely and safely, making it easier to test the code and do trial and error.
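For example, with the parameters held in a NamedTuple, a variant with a stronger field or a different number type is just a new value passed to the same function, and nothing global has to be redefined. This sketch re-declares `Fx` with plain field access instead of `@unpack` so that it needs only Base Julia; `p_strong` and `p32` are hypothetical names.

```julia
# Same pulse shape as Fx(p, t) above, using plain field access
function Fx(p, t)
    F0, N, w, T = p.F0, p.N, p.w, p.T
    F0 * cos(w*t/(2N))^2 * cos(w*t) * (abs(t) < N*T/2)
end

w = 0.057
p = (; Ip = 0.5, N = 2, w, T = 2π/w, F0 = 0.075, a = 1.0)

p_strong = merge(p, (; F0 = 0.15))  # change one value, keep the rest
p32 = map(Float32, p)               # same parameters, all converted to Float32

Fx(p, 0.0), Fx(p_strong, 0.0), Fx(p32, 0.0f0)
```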

If we pass the necessary parameters to a function as arguments every time, the code becomes much easier to read, since the function depends only on what appears in its own definition.

These points are not specific to Julia; they apply equally when writing C/C++ or Fortran.

In Julia, following these rules often gives better performance as well, as the timings above show.


@genkuroki Thanks a lot for your kind explanations; they have been a great help, and I very much appreciate your warmheartedness!