Hello friends, I have an ODE that I need to solve for many different initial conditions, which vary in a nested loop. In my old code I simply define the `ODEProblem` inside the inner loop. This is the simplest approach but it is slow; I suspect it causes a lot of unnecessary memory allocations.

I tried to speed it up with `EnsembleProblem` (from the Parallel Ensemble Simulations docs), but surprisingly it became slower. Then I tried defining the `ODEProblem` in the outer loop, which is multi-threaded, and using the `remake` function to update the initial conditions in the inner loop. This does speed things up substantially. Afterwards, I tried to speed it up further by defining the `ODEProblem` outside the loop entirely, but surprisingly that gives only a very minor additional speedup.
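
For reference, an `EnsembleProblem` version of the inner `kd` loop might look roughly like the sketch below. It is an illustration, not the code that turned out slower: the helper name `final_states` and the single ionization time `tr = 0.0` are made up for the example, and the `prob_func(prob, i, repeat)` and `output_func(sol, i)` signatures follow the SciMLBase ensemble documentation as I understand it, so they are worth checking against the installed version. `output_func` keeps only the final state, so full solution objects are not stored.

```julia
using DifferentialEquations
using StaticArrays

# Model from the post: laser field plus soft-core Coulomb potential
const Ip = 0.5
const N = 2
const w = 0.057
const T = 2π/w
const F0 = 0.075
const a = 1.0
Fx(t) = F0 * cos(w*t/(2N))^2 * cos(w*t) * (abs(t) < N*T/2)
Fy(t) = F0 * cos(w*t/(2N))^2 * sin(w*t) * (abs(t) < N*T/2)
function traj(u, p, t)
    r3i = (u[1]^2 + u[2]^2 + a)^(-1.5)
    return SVector(u[3], u[4], -u[1]*r3i - Fx(t), -u[2]*r3i - Fy(t))
end

# Hypothetical helper: final states of all kd trajectories for one tr
function final_states(tr, kds)
    Ftr = hypot(Fx(tr), Fy(tr))
    phi = atan(-Fy(tr), -Fx(tr))
    r0 = Ip / Ftr
    x0, y0 = r0*cos(phi), r0*sin(phi)
    base = ODEProblem(traj, SVector(x0, y0, 0.0, 0.0), (tr, N*T/2))
    # trajectory i: same start point, i-th momentum perpendicular to the field
    prob_func = (prob, i, repeat) -> remake(prob;
        u0 = SVector(x0, y0, kds[i]*cos(phi + π/2), kds[i]*sin(phi + π/2)))
    # keep only the final state; `false` means "do not rerun this trajectory"
    output_func = (sol, i) -> (sol.u[end], false)
    eprob = EnsembleProblem(base; prob_func = prob_func, output_func = output_func)
    sim = solve(eprob, Tsit5(), EnsembleThreads(); trajectories = length(kds),
                reltol = 1e-6, save_everystep = false, save_start = false)
    return sim.u
end

states = final_states(0.0, LinRange(-2.0, 2.0, 100))
```

Calling such a function from inside the already-threaded `tr` loop would nest two levels of threading; whether that explains the slowdown is only a guess.
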
The following is the simplest original code:
```julia
using DifferentialEquations
using BenchmarkTools
using StaticArrays
using Base.Threads
using PyPlot

# define system parameters
const Ip = 0.5
const N = 2
const w = 0.057
const T = 2π/w
const F0 = 0.075

# define laser pulse (zero outside |t| < N*T/2)
Fx(t) = F0 * cos(w*t/(2N))^2 * cos(w*t) * (abs(t)<N*T/2)
Fy(t) = F0 * cos(w*t/(2N))^2 * sin(w*t) * (abs(t)<N*T/2)

# define rate
W_adk(F,kd) = exp(-(2.0*(kd^2+2Ip)^1.5)/3F)

# define electron trajectory: out-of-place RHS returning an SVector
const a = 1.0
function traj(u,p,t)
    r3i = (u[1]^2+u[2]^2+a)^(-1.5)
    du1 = u[3]
    du2 = u[4]
    du3 = -u[1]*r3i-Fx(t)
    du4 = -u[2]*r3i-Fy(t)
    @SVector [du1,du2,du3,du4]
end

# define result box (final-momentum histogram grid)
const Px_min = -2.0
const Px_max = 2.0
const Px_num = 200
const Px = LinRange(Px_min,Px_max,Px_num)
const dPx = (Px_max-Px_min)/(Px_num-1)
const Py_min = -2.0
const Py_max = 2.0
const Py_num = 200
const Py = LinRange(Py_min,Py_max,Py_num)
const dPy = (Py_max-Py_min)/(Py_num-1)

# define simulation parameters
const Nt = 500
const Nkd = 100
const kd_max = 2.0

# classical trajectory simulation
# One histogram slice per thread. :static scheduling keeps each iteration on a
# single thread, so indexing by threadid() is safe.
# Note: @benchmark runs the loop many times, so the histogram accumulates over all samples.
const Prob_nthreads = zeros(Float64,Px_num,Py_num,nthreads())
@benchmark @threads :static for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
    # field at time tr; start at distance Ip/F opposite to the field direction
    Fxtr = Fx(tr)
    Fytr = Fy(tr)
    Ftr = hypot(Fxtr,Fytr)
    phi = atan(-Fytr,-Fxtr)
    r0 = Ip/Ftr
    x0 = r0*cos(phi)
    y0 = r0*sin(phi)
    tspan = (tr,N*T/2)
    for kd in LinRange(-kd_max,kd_max,Nkd)
        # initial momentum kd perpendicular to the field direction
        kx0 = kd*cos(phi+0.5π)
        ky0 = kd*sin(phi+0.5π)
        rate = W_adk(Ftr,kd)
        u0 = @SVector [x0,y0,kx0,ky0]
        # a new ODEProblem is built for every single trajectory
        Traj = ODEProblem(traj,u0,tspan)
        sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
        (x,y,px,py) = sol.u[end]
        # skip electrons whose final energy is negative (still bound)
        E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
        E_inf >= zero(E_inf) || continue
        # bin the final momentum, weighted by the rate
        pxidx = Int64(round((px-Px[1])/dPx))+1
        pyidx = Int64(round((py-Py[1])/dPy))+1
        checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
        @inbounds Prob_nthreads[pxidx,pyidx,threadid()] += rate
    end
end
# merge the per-thread histograms
const Prob = reshape(sum(Prob_nthreads,dims=3),size(Prob_nthreads)[1:2])

# plotting
fig,ax = plt.subplots()
ax.imshow(Prob',extent=(Px[1],Px[end],Py[1],Py[end]),origin="lower")
ax.set_xlim(Px[1],Px[end])
ax.set_ylim(Py[1],Py[end])
ax.set_xlabel("\$p_x\$ (a.u.)")
ax.set_ylabel("\$p_y\$ (a.u.)")
plt.show()
```
Defining the `ODEProblem` in the outer loop speeds it up a lot:
```julia
# ODEProblem built once per tr (outer loop); remake updates u0 in the inner loop
@benchmark @threads :static for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
    Fxtr = Fx(tr)
    Fytr = Fy(tr)
    Ftr = hypot(Fxtr,Fytr)
    phi = atan(-Fytr,-Fxtr)
    r0 = Ip/Ftr
    x0 = r0*cos(phi)
    y0 = r0*sin(phi)
    # placeholder problem for this tr; its u0 is replaced below
    Traj = ODEProblem(traj,zeros(Float64,4),(tr,N*T/2))
    for kd in LinRange(-kd_max,kd_max,Nkd)
        kx0 = kd*cos(phi+0.5π)
        ky0 = kd*sin(phi+0.5π)
        rate = W_adk(Ftr,kd)
        u0 = @SVector [x0,y0,kx0,ky0]
        Traj = remake(Traj;u0=u0)
        sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
        (x,y,px,py) = sol.u[end]
        E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
        E_inf >= zero(E_inf) || continue
        pxidx = Int64(round((px-Px[1])/dPx))+1
        pyidx = Int64(round((py-Py[1])/dPy))+1
        checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
        @inbounds Prob_nthreads[pxidx,pyidx,threadid()] += rate
    end
end
```
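
In the loop above, the placeholder problem is created with `zeros(Float64,4)`, a heap-allocated `Vector`, while `remake` is passed an `SVector`. Because the state type is one of `ODEProblem`'s type parameters, the first `remake` returns a problem of a different type than the placeholder, so the variable `Traj` does not have a single concrete type inside the loop. The check below uses a hypothetical toy right-hand side `rhs`; creating the placeholder with an `SVector` (for example `zeros(SVector{4,Float64})`) avoids the type change. How much time this costs in the real loop is not measured.

```julia
using DifferentialEquations
using StaticArrays

# Hypothetical out-of-place right-hand side, only for this type check
rhs(u, p, t) = -u
tspan = (0.0, 1.0)

prob_vec = ODEProblem(rhs, zeros(Float64, 4), tspan)          # Vector placeholder, as in the post
prob_sv  = ODEProblem(rhs, zeros(SVector{4,Float64}), tspan)  # SVector placeholder

new_u0 = SVector(1.0, 0.0, 0.0, 0.5)
remade_vec = remake(prob_vec; u0 = new_u0)
remade_sv  = remake(prob_sv;  u0 = new_u0)

# The Vector placeholder and its remade copy have different types
println(typeof(prob_vec) == typeof(remade_vec))  # false
```
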
But if I define the `ODEProblem` outside the loop, there is barely any further improvement:
```julia
# one ODEProblem per thread, built once before the loop
const Trajs = Vector{ODEProblem}(undef,nthreads())
for i in 1:nthreads()
    Trajs[i] = ODEProblem(traj,zeros(Float64,4),(-N*T/2,N*T/2))
end
@benchmark @threads :static for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
    Fxtr = Fx(tr)
    Fytr = Fy(tr)
    Ftr = hypot(Fxtr,Fytr)
    phi = atan(-Fytr,-Fxtr)
    r0 = Ip/Ftr
    x0 = r0*cos(phi)
    y0 = r0*sin(phi)
    tspan = (tr,N*T/2)
    for kd in LinRange(-kd_max,kd_max,Nkd)
        kx0 = kd*cos(phi+0.5π)
        ky0 = kd*sin(phi+0.5π)
        rate = W_adk(Ftr,kd)
        u0 = @SVector [x0,y0,kx0,ky0]
        # take this thread's problem and update both u0 and tspan
        Traj = remake(Trajs[threadid()];u0=u0,tspan=tspan)
        sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
        (x,y,px,py) = sol.u[end]
        E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
        E_inf >= zero(E_inf) || continue
        pxidx = Int64(round((px-Px[1])/dPx))+1
        pyidx = Int64(round((py-Py[1])/dPy))+1
        checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
        @inbounds Prob_nthreads[pxidx,pyidx,threadid()] += rate
    end
end
```
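
`Vector{ODEProblem}` as used above has an abstract element type, because `ODEProblem` without its type parameters is not a concrete type, so every `Trajs[threadid()]` lookup is type-unstable. That could eat part of the gain from hoisting the construction; another possibility is simply that building an `ODEProblem` is cheap compared with `solve`. Both are guesses, not measurements. A comprehension lets Julia pick a concrete element type, as this small check with a hypothetical toy `rhs` shows.

```julia
using DifferentialEquations
using StaticArrays
using Base.Threads

rhs(u, p, t) = -u   # hypothetical toy right-hand side
u0 = zeros(SVector{4,Float64})
tspan = (0.0, 1.0)

# As in the post: element type is the abstract (unparameterized) ODEProblem
abstract_probs = Vector{ODEProblem}(undef, nthreads())
for i in 1:nthreads()
    abstract_probs[i] = ODEProblem(rhs, u0, tspan)
end

# Comprehension: the element type is the fully parameterized, concrete problem type
concrete_probs = [ODEProblem(rhs, u0, tspan) for _ in 1:nthreads()]

println(isconcretetype(eltype(abstract_probs)))  # false
println(isconcretetype(eltype(concrete_probs)))  # true
```

With the post's names, the corresponding version would be `const Trajs = [ODEProblem(traj, zeros(SVector{4,Float64}), (-N*T/2, N*T/2)) for _ in 1:nthreads()]`.
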
Can anyone give me a hint about a better way to speed things up? Thanks!
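
A minimal sketch of one commonly suggested restructuring: put the work in a function instead of global scope, split the `tr` values into one chunk per task, and give each task its own histogram so that nothing depends on `threadid()`. The ODE solve is replaced by a hypothetical stand-in `fake_final_state` so the pattern runs on plain Julia; in the real code its body would be the `remake`/`solve` step and `W_adk` rate from the loops above. The function name `momentum_histogram` is also made up, and its speed on this problem is untested.

```julia
# Hypothetical stand-in for one trajectory: returns (px, py, weight).
# In the real code this would be the remake/solve step plus W_adk.
fake_final_state(tr, kd) = (sin(tr) * kd, cos(tr) * kd, exp(-kd^2))

function momentum_histogram(trs, kds; nbins = 200, pmin = -2.0, pmax = 2.0)
    dp = (pmax - pmin) / (nbins - 1)
    # one contiguous chunk of ionization times per task
    chunks = Iterators.partition(trs, cld(length(trs), Threads.nthreads()))
    tasks = map(chunks) do chunk
        Threads.@spawn begin
            H = zeros(Float64, nbins, nbins)   # private to this task
            for tr in chunk, kd in kds
                px, py, wt = fake_final_state(tr, kd)
                i = round(Int, (px - pmin) / dp) + 1
                j = round(Int, (py - pmin) / dp) + 1
                checkbounds(Bool, H, i, j) || continue
                H[i, j] += wt
            end
            H
        end
    end
    # merge the per-task histograms once, after all tasks have finished
    return sum(fetch, tasks)
end

H = momentum_histogram(LinRange(-1.0, 1.0, 50), LinRange(-2.0, 2.0, 10))
```

Because each task owns its buffer, the result does not depend on how the scheduler maps tasks to threads, and the final `sum` replaces the `reshape(sum(...; dims=3), ...)` step.
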