Solve ODE with many different initial conditions

The most recent version of the original code (slightly modified):

# Baseline version. For every start time `tr` and every transverse momentum
# `kd`, integrate one trajectory and, if its final `E_inf` is non-negative, add
# the weight `W_adk(Ftr,kd)` to the (px,py) bin the trajectory ends in.
# The accumulator has one (Px_num, Py_num) slab per thread, so no two threads
# ever write to the same element. The `let` block evaluates to that array.
Prob_nthreads = let Prob_nthreads = zeros(Float64,Px_num,Py_num,nthreads())
    # One template problem per thread. The eltype is intentionally the abstract
    # type ODEProblem here: this is the "before" case of the comparison.
    Trajs = Vector{ODEProblem}(undef,nthreads())
    for i in eachindex(Trajs)
        Trajs[i] = ODEProblem(traj,zeros(Float64,4),(-N*T/2,N*T/2))
    end
    @show typeof(Trajs)
    t_end = N*T/2  # common end time of every trajectory
    # `:static` runs one task per thread and keeps threadid() constant within an
    # iteration, which the threadid()-indexed buffers below rely on. Without it,
    # Julia versions whose default schedule is dynamic may migrate or interleave
    # tasks and make the `+=` on Prob_nthreads race.
    # NOTE(review): the buffers are sized with nthreads(); if Julia is started
    # with an extra interactive thread pool, thread ids may exceed that - confirm.
    @time @threads :static for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
        Fxtr = Fx(tr)
        Fytr = Fy(tr)
        Ftr = hypot(Fxtr,Fytr)
        phi = atan(-Fytr,-Fxtr)  # direction of the vector (-Fx, -Fy)
        # Start position: distance Ip/Ftr from the origin along direction phi.
        r0 = Ip/Ftr
        x0 = r0*cos(phi)
        y0 = r0*sin(phi)
        tspan = (tr,t_end)
        # Direction phi rotated by π/2; it does not depend on kd, so it is
        # computed once per `tr` instead of once per inner iteration.
        cperp = cos(phi+0.5π)
        sperp = sin(phi+0.5π)
        for kd in LinRange(-kd_max,kd_max,Nkd)
            kx0 = kd*cperp
            ky0 = kd*sperp
            rate = W_adk(Ftr,kd)

            u0 = @SVector [x0,y0,kx0,ky0]
            # Reuse this thread's template problem; only u0 and tspan change.
            Traj = remake(Trajs[threadid()];u0=u0,tspan=tspan)
            # Only the final state is needed, so intermediate saving is off.
            sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
            (x,y,px,py) = sol.u[end]

            # NOTE(review): the denominator has no square root - confirm that
            # this matches the potential used in `traj`.
            E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
            E_inf >= zero(E_inf) || continue
            # Nearest grid index, assuming uniform spacing dPx/dPy from Px[1]/Py[1].
            pxidx = round(Int64,(px-Px[1])/dPx)+1
            pyidx = round(Int64,(py-Py[1])/dPy)+1
            # Final momenta outside the grid are dropped.
            checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
            @inbounds Prob_nthreads[pxidx,pyidx,threadid()] += rate
        end
    end
    Prob_nthreads
end

Output:

typeof(Trajs) = Vector{ODEProblem}

0.261037 seconds (3.08 M allocations: 256.142 MiB, 0.00% compilation time)

Slightly improved version:

# Improved version. For every start time `tr` and every transverse momentum
# `kd`, integrate one trajectory and, if its final `E_inf` is non-negative, add
# the weight `W_adk(Ftr,kd)` to the (px,py) bin the trajectory ends in.
# The accumulator has one (Px_num, Py_num) slab per thread, so no two threads
# ever write to the same element. The `let` block evaluates to that array.
Prob_nthreads = let Prob_nthreads = zeros(Float64,Px_num,Py_num,nthreads())
    # One template problem per thread. The comprehension gives the vector a
    # concrete eltype, and the SVector initial state matches the u0 used below.
    Trajs = [ODEProblem(traj,@SVector(zeros(Float64,4)),(-N*T/2,N*T/2)) for _ in 1:nthreads()]
    @show typeof(Trajs)
    t_end = N*T/2  # common end time of every trajectory
    # `:static` runs one task per thread and keeps threadid() constant within an
    # iteration, which the threadid()-indexed buffers below rely on. Without it,
    # Julia versions whose default schedule is dynamic may migrate or interleave
    # tasks and make the `+=` on Prob_nthreads race.
    # NOTE(review): the buffers are sized with nthreads(); if Julia is started
    # with an extra interactive thread pool, thread ids may exceed that - confirm.
    @time @threads :static for tr in LinRange(-0.9N*T/2,0.9N*T/2,Nt)
        Fxtr = Fx(tr)
        Fytr = Fy(tr)
        Ftr = hypot(Fxtr,Fytr)
        phi = atan(-Fytr,-Fxtr)  # direction of the vector (-Fx, -Fy)
        # Start position: distance Ip/Ftr from the origin along direction phi.
        r0 = Ip/Ftr
        x0 = r0*cos(phi)
        y0 = r0*sin(phi)
        tspan = (tr,t_end)
        # Direction phi rotated by π/2; it does not depend on kd, so it is
        # computed once per `tr` instead of once per inner iteration.
        cperp = cos(phi+0.5π)
        sperp = sin(phi+0.5π)
        for kd in LinRange(-kd_max,kd_max,Nkd)
            kx0 = kd*cperp
            ky0 = kd*sperp
            rate = W_adk(Ftr,kd)

            u0 = @SVector [x0,y0,kx0,ky0]
            # Reuse this thread's template problem; only u0 and tspan change.
            Traj = remake(Trajs[threadid()];u0=u0,tspan=tspan)
            # Only the final state is needed, so intermediate saving is off.
            sol = solve(Traj,Tsit5(),reltol=1e-6,save_everystep=false,save_start=false)
            (x,y,px,py) = sol.u[end]

            # NOTE(review): the denominator has no square root - confirm that
            # this matches the potential used in `traj`.
            E_inf = 0.5*(px^2+py^2)-1/(x^2+y^2+a)
            E_inf >= zero(E_inf) || continue
            # Nearest grid index, assuming uniform spacing dPx/dPy from Px[1]/Py[1].
            pxidx = round(Int64,(px-Px[1])/dPx)+1
            pyidx = round(Int64,(py-Py[1])/dPy)+1
            # Final momenta outside the grid are dropped.
            checkbounds(Bool,Prob_nthreads,pxidx,pyidx,threadid()) || continue
            @inbounds Prob_nthreads[pxidx,pyidx,threadid()] += rate
        end
    end
    Prob_nthreads
end

Output:

typeof(Trajs) = Vector{ODEProblem{SVector{4, Float64}, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, ODEFunction{false, typeof(traj), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}

0.200081 seconds (1.21 M allocations: 158.318 MiB, 0.00% compilation time)

0.261037 seconds (3.08 M allocations: 256.142 MiB, 0.00% compilation time)

0.200081 seconds (1.21 M allocations: 158.318 MiB, 0.00% compilation time)

The allocations have been reduced.

The types of Trajs in the two versions above are different. In the former, Trajs is a Vector whose element type is the abstract type ODEProblem; in the latter, it is a Vector whose element type is a (much longer) concrete type. Whenever possible, it is better to avoid arrays whose eltype is abstract.

Furthermore, the latter uses @SVector even in the definition of Trajs so that the type of Traj is equal to that of Trajs[threadid()].

Postscript: Jupyter notebook

In the code above, I have used a let block to avoid using const. However, when code is not wrapped in a function, many useful tools such as @code_warntype become difficult to use.

After testing the code above, as well as others I wrote during the trial and error process, I am again convinced that the use of const should be avoided whenever possible.

If we put the parameters describing the problem in one variable, and the parameters describing how to solve the problem in another variable, and write code that always passes them as arguments to the functions used to solve the problem, we can write code that is readable, efficient, and easy to test, without using const.

3 Likes