Pairwise computation slower than Python (Cython) code (BallTree very slow!)

Hi everyone! I’m new to Julia, and so far I have found it very nice and neat. However, my Julia implementation is still slower than the Python one I was trying to beat. The goal is to compute the mean radial pairwise velocity of objects as a function of their separation, given their 3D velocities and positions. Note that these objects are inside a periodic box, so all distances need to account for periodic boundary conditions. To avoid computing more pairwise distances than necessary, I use a BallTree to find the objects within the maximum allowed distance.

Here are the functions I defined:

using Distances
using LinearAlgebra
using NearestNeighbors

export get_pairwise_velocity_distribution

# minimum-image separation of two coordinates along one axis of a periodic box
function get_periodic_difference( x, y, period)
    delta = abs(x - y)
    return ifelse(delta > 0.5 * period, delta - period, delta )* sign(x-y)
end

function get_pairwise_velocities!(
                    ret, 
                    dv, 
                    pos_pair_left,
                    pos_pair_right,
                    vel_pair_left,
                    vel_pair_right,
                    boxsize
    )
    for i in 1:length(ret)
        dv[i] = vel_pair_left[i] - vel_pair_right[i]
        ret[i] = get_periodic_difference(pos_pair_left[i], pos_pair_right[i], boxsize[i])
    end
    r = LinearAlgebra.norm(ret)
    v_r = LinearAlgebra.dot(dv,ret)/r
    #v_t = sqrt(LinearAlgebra.dot(dv, dv) - v_r*v_r)/sqrt(2.)
    return r, v_r
end

function get_pairwise_velocity_distribution(
        positions, velocities,
        rbins,
        boxsize
        )
    r_max = maximum(rbins)
    metric = PeriodicEuclidean(boxsize)
    # NearestNeighbors expects one point per column, so transpose the N×3 matrix
    transposed_positions = permutedims(positions)
    balltree = BallTree(transposed_positions, metric)
    mean_v_r = zeros(Float64, (length(rbins)-1))
    n_pairs = zeros(Int32, (length(rbins)-1))
    # for every object, the (unsorted) indices of all objects within r_max
    idxs = inrange(balltree, transposed_positions, r_max, false)
    ret = zeros(Float64, size(positions)[2])
    dv = zeros(Float64, size(positions)[2])
    for i in 1:size(positions,1) 
        for j in idxs[i]
            if j > i
                pos_i = @view positions[i,:]
                pos_j = @view positions[j,:]
                vel_i = @view velocities[i,:]
                vel_j = @view velocities[j,:]
                r, v_r = get_pairwise_velocities!(ret, dv, pos_i, pos_j, vel_i, vel_j, boxsize)
                if first(rbins) < r < last(rbins)
                    rbin = searchsortedfirst(rbins, r) - 1
                    mean_v_r[rbin] += v_r
                    n_pairs[rbin] += 1
                end
            end
        end
    end
    mean_v_r[n_pairs .> 0] = mean_v_r[n_pairs .> 0]./n_pairs[n_pairs .> 0]
    return mean_v_r
end
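A quick way to convince oneself that get_periodic_difference implements the minimum-image convention is to compare it with a brute-force search over the neighbouring periodic images. This sketch assumes every coordinate lies in [0, period); brute_force_difference is a hypothetical helper, not part of any package.

```julia
# get_periodic_difference as posted above: minimum-image separation along one axis
function get_periodic_difference(x, y, period)
    delta = abs(x - y)
    return ifelse(delta > 0.5 * period, delta - period, delta) * sign(x - y)
end

# Hypothetical brute-force reference: among the images y - period, y, y + period,
# return the separation x - image with the smallest magnitude.
# Assumes 0 <= x, y < period, so |x - y| < period and these three images suffice.
function brute_force_difference(x, y, period)
    best = x - y
    for k in (-1, 1)
        d = x - (y + k * period)
        if abs(d) < abs(best)
            best = d
        end
    end
    return best
end

# Compare both on random coordinates in a box of side 250
box = 250.0
xs, ys = box .* rand(10_000), box .* rand(10_000)
maxerr = maximum(abs.(get_periodic_difference.(xs, ys, box) .- brute_force_difference.(xs, ys, box)))
println("max |difference| = ", maxerr)   # should be at floating-point rounding level
```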

You can find the package and the benchmarking scripts on my github repo https://github.com/florpi/PairVelocities

When the number of objects (here called halos) is small, the Julia code wins. However, when the number of objects is larger than 10_000, the Python code becomes faster by more than an order of magnitude (for instance, for 100_000 objects Python takes 0.42 seconds and Julia 6 seconds). This is the output from @btime when I use 1_000_000 objects:

207.891 s (8775469 allocations: 7.19 GiB)

All those allocations come from the BallTree calls.

Do you have any idea how I could improve it so that it beats the Python implementation? I would appreciate your help.


What about changing the searchsortedfirst line to something like:

rbin = floor(Int, r/Δr) + 1

where Δr is the (constant) bin width and the bins start at zero? I cannot test it now, but I have the impression that you do not need a search there.

Do you know whether it is the loop or the BallTree that takes most of the time? I don’t know if NearestNeighbors.jl is the most appropriate package for what you want to compute.


I think searchsortedfirst is necessary for the code to work with any type of rbins (logarithmically spaced, or whatever the user wants).
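Both points can be reconciled: for equally spaced or logarithmically spaced edges the bin index can be computed in O(1), keeping searchsortedfirst as the general fallback for arbitrary edges. A sketch, where search_bin, linear_bin and log_bin are hypothetical helper names:

```julia
# General case, as in the original code: any sorted bin edges
search_bin(r, rbins) = searchsortedfirst(rbins, r) - 1

# Hypothetical O(1) special cases
# Linear edges: rbins = rmin:Δr:rmax
linear_bin(r, rmin, Δr) = floor(Int, (r - rmin) / Δr) + 1
# Logarithmic edges: rbins[k] = rmin * ratio^(k - 1), with logratio = log(ratio)
log_bin(r, rmin, logratio) = floor(Int, log(r / rmin) / logratio) + 1

rbins_lin = 0.0:2.0:10.0      # edges 0, 2, ..., 10 (5 bins)
rbins_log = 10.0 .^ (0:3)     # edges 1, 10, 100, 1000 (3 bins)

println(linear_bin(3.0, first(rbins_lin), step(rbins_lin)), " ", search_bin(3.0, rbins_lin))  # prints 2 2
println(log_bin(50.0, first(rbins_log), log(10.0)), " ", search_bin(50.0, rbins_log))          # prints 2 2
```

The two conventions differ exactly on a bin edge: floor puts r == edge into the upper bin, while searchsortedfirst(rbins, r) - 1 puts it into the lower one, and round-off in log can also matter very close to an edge.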

Regarding the BallTree question: with 100_000 objects, the calls to BallTree and inrange take about 6.813 seconds, whereas the full code takes about 7 seconds. So yes, most of the time is spent in the BallTree. Do you know of any alternative I could use to efficiently find, for every object, all objects within a maximum distance?


I know of the package below, but it is still somewhat experimental. It is also possible that NearestNeighbors.jl is usable, but with a different calling strategy.

https://github.com/JuliaMolSim/NeighbourLists.jl

I am not sure we have a state-of-the-art implementation of cell lists in Julia yet. When I needed them I implemented them myself, but only for the limited types of boxes and distance measures I was interested in.


Thank you very much for your answer! I’ll have a look at your implementation.

I found this now: https://github.com/JuliaNeighbors

Maybe some of those packages help.

(The Python implementation seems to be calling some package, which might already be a very efficient implementation of what you want. If there is no better alternative and you want to keep the rest of the code in Julia, perhaps you can use PyCall, or call the lower-level routine directly if clear interfaces are available.)

Yes, I think the Python implementation is very optimized. However, I need to modify it to do something else I’d like to do, and modifying that Cython code seems more daunting than optimizing the Julia one.


Where is the cython code? I could not find it.

It’s here https://github.com/astropy/halotools/blob/v0.7/halotools/mock_observables/pairwise_velocities/engines/mean_radial_velocity_vs_r_engine.pyx


What is there is pretty much the same as what I implemented. It only looks complicated because of the type conversions at the beginning, which would not be necessary in Julia. The first package I mentioned above is similar, maybe even faster, and, as far as I know, deals with general periodic boundaries.

Is this the package you mean https://github.com/JuliaMolSim/NeighbourLists.jl ?


The code starts here:
https://github.com/astropy/halotools/blob/v0.7/halotools/mock_observables/pairwise_velocities/mean_radial_velocity_vs_r.py#L28
It first subdivides the space into cells forming a RectangularDoubleMesh, with the cell sizes based on the bin sizes.
The calculation is then performed for each of these cells.
Computing which cell each object belongs to is probably much faster than building the BallTree.
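The cell-list idea can be sketched in a few lines. This is a minimal serial version written for illustration, not the halotools algorithm or any package’s implementation. It assumes a cubic periodic box of side L with coordinates in [0, L), positions stored as a 3×N matrix (one object per column), and L / rcut ≥ 3, so that the 27 neighbouring cells of every cell are distinct. All function names are hypothetical. It counts pairs closer than rcut and can be checked against an O(N²) loop.

```julia
# Minimum-image separation along one axis of a periodic box of side L
wrap(d, L) = d - L * round(d / L)

# 1-based integer cell coordinates of point p in a grid of ncell^3 cubic cells
cell_of(p, L, ncell) = ntuple(k -> clamp(floor(Int, p[k] / L * ncell) + 1, 1, ncell), 3)

# Squared minimum-image distance between objects (columns) i and j of pos
function dist2(pos, i, j, L)
    s = 0.0
    for k in 1:3
        s += wrap(pos[k, i] - pos[k, j], L)^2
    end
    return s
end

# Count pairs closer than rcut, comparing each object only with objects
# in its own cell and in the 26 surrounding cells (with periodic wrapping)
function count_pairs_cell_list(pos::AbstractMatrix, L, rcut)
    ncell = floor(Int, L / rcut)             # cell side L / ncell >= rcut
    ncell >= 3 || throw(ArgumentError("this sketch needs L / rcut >= 3"))
    # 1. assign every object to its cell: O(N)
    cells = [Int[] for a in 1:ncell, b in 1:ncell, c in 1:ncell]
    for i in axes(pos, 2)
        push!(cells[cell_of(view(pos, :, i), L, ncell)...], i)
    end
    # 2. loop over cells and their 27 neighbours (including the cell itself)
    npairs = 0
    for c in CartesianIndices(cells), off in CartesianIndices((-1:1, -1:1, -1:1))
        nb = CartesianIndex(mod1.(Tuple(c) .+ Tuple(off), ncell))
        for i in cells[c], j in cells[nb]
            i < j || continue                # count each unordered pair once
            npairs += dist2(pos, i, j, L) < rcut^2
        end
    end
    return npairs
end

# O(N^2) brute-force reference
count_pairs_brute(pos, L, rcut) =
    count(dist2(pos, i, j, L) < rcut^2 for i in 1:size(pos, 2) for j in i+1:size(pos, 2))

pos = 250.0 .* rand(3, 2_000)              # random objects in a box of side 250
count_pairs_cell_list(pos, 250.0, 10.0) == count_pairs_brute(pos, 250.0, 10.0)
```

Binning costs O(N), and each object is then compared only with the objects in 27 cells rather than with all others. For the mean radial velocity, the pair counter would be replaced by the binning and velocity accumulation of the original loop.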


That one.

I will try to split what I did into a separate package, at least to compare it. If it is not as fast, it is better that I know it.


I have put up a package with my implementation of Cell Lists here:

I am not sure if that is exactly what you are looking for. You can pass any function to the pairwise calculation, to compute histograms of distance-dependent properties, an average distance, a potential, a force vector, or anything like that. There are some examples in the README file.

Concerning the performance: for 100_000 particles, a box side of 250 (cubic in this case) and a cutoff of 10, which appears to be what you tested, I can build a histogram of the distances (within 10, of course) in:

julia> @btime CellLists.test2(100_000)
  375.571 ms (38 allocations: 4.58 MiB)

which is roughly similar to what you reported for the Python/Cython implementation there (although I have no idea how our computers compare).

But the implementation is threaded, and it scales quite decently. My laptop has 4 cores, and if I run with julia -t4 I get:

julia> @btime CellLists.test2(100_000)
  127.368 ms (68 allocations: 4.58 MiB)

Using -t8, there is a small additional speedup:

julia> @btime CellLists.test2(100_000)
  95.421 ms (104 allocations: 4.59 MiB)

I hope that is useful. My goal is to make the implementation generic enough for any type of periodic boundary (currently it only works with orthorhombic boxes), and maybe for any dimension (although I am not sure that is really useful).

Anyway, any contribution to improve it is most welcome.
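Threaded scaling like the one above requires avoiding data races when several threads update the same histogram. A common way to do that, shown here as a generic sketch and not as the package’s actual code, is to give each task a private histogram and sum the partial results at the end. For brevity the example uses 1D, non-periodic positions; threaded_hist and serial_hist are hypothetical names.

```julia
# Pair-distance histogram with per-task private accumulators (1D, no periodicity)
function threaded_hist(x::Vector{Float64}, nbins::Int, width::Float64;
                       ntasks::Int = Threads.nthreads())
    n = length(x)
    chunks = Iterators.partition(1:n, cld(n, ntasks))   # split the outer loop
    tasks = map(chunks) do irange
        Threads.@spawn begin
            h = zeros(Int, nbins)            # private to this task: no data race
            for i in irange, j in i+1:n
                ibin = floor(Int, abs(x[i] - x[j]) / width) + 1
                if ibin <= nbins
                    h[ibin] += 1
                end
            end
            h
        end
    end
    return sum(fetch.(tasks))                # reduce the partial histograms
end

# Serial reference for comparison
function serial_hist(x::Vector{Float64}, nbins::Int, width::Float64)
    h = zeros(Int, nbins)
    for i in eachindex(x), j in i+1:length(x)
        ibin = floor(Int, abs(x[i] - x[j]) / width) + 1
        ibin <= nbins && (h[ibin] += 1)
    end
    return h
end
```

Start Julia with, e.g., `julia -t4` to actually use several threads. Splitting the outer loop into equal chunks gives unequal work per task (small i has more partners), so this sketch is not load-balanced.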


There’s also jaantollander/CellLists.jl on GitHub: a Julia language implementation of the Cell Lists algorithm to solve the fixed-radius near neighbors problem, including serial and multithreaded algorithms.


Note that Julia (unlike NumPy, whose arrays are row-major by default) is column-major, so a row view such as @view positions[i,:] is a non-contiguous (strided) slice in memory.

This kind of code is a perfect case for using StaticArrays.jl for the coordinate vectors: instead of a 2D array, have a 1D array of SVector values. Then you can also use vector operations, which are fast (inlined, without heap allocations), and the code is shorter as well.
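Following that suggestion, here is a sketch of how the data could be converted and how get_pairwise_velocities! could then be written without the ret/dv buffers. It assumes the N×3 layout of the original post and that StaticArrays.jl is installed; to_svectors, periodic_diff and pairwise_radial_velocity are hypothetical names.

```julia
using StaticArrays, LinearAlgebra

# N×3 matrix (one object per row, as in the original post) -> Vector of SVectors.
# Each object's coordinates become contiguous, immutable and stack-allocated.
to_svectors(A::AbstractMatrix) =
    [SVector{3,Float64}(A[i, 1], A[i, 2], A[i, 3]) for i in axes(A, 1)]

# Minimum-image separation along one axis (same logic as get_periodic_difference)
function periodic_diff(x, y, period)
    delta = abs(x - y)
    return ifelse(delta > 0.5 * period, delta - period, delta) * sign(x - y)
end

# Separation and radial relative velocity of one pair; no preallocated buffers
# are needed because arithmetic on SVectors does not allocate.
function pairwise_radial_velocity(xi::SVector{3}, xj::SVector{3},
                                  vi::SVector{3}, vj::SVector{3}, box::SVector{3})
    d = periodic_diff.(xi, xj, box)   # broadcasting over SVectors yields an SVector
    r = norm(d)
    return r, dot(vi - vj, d) / r
end

pos = [0.1 0.2 0.3; 0.9 0.2 0.3]     # two objects, one per row
vel = [1.0 0.0 0.0; 0.0 0.0 0.0]
x, v = to_svectors(pos), to_svectors(vel)
box = SVector(1.0, 1.0, 1.0)
r, v_r = pairwise_radial_velocity(x[1], x[2], v[1], v[2], box)
```

Here the two objects are 0.2 apart across the periodic boundary along the first axis, and the relative velocity points along their separation, so v_r is about 1.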


Just to let you know that I have already made a breaking change. The inner function now receives the indices of the particles and the squared distance (already computed), which can be used inside it. I’ve added examples in which I compute a gravitational potential and a force with it.

(In all the examples I used static arrays, as suggested by Steven).

And I will probably change the name of the repository, of course.
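To illustrate the calling convention described here (a user function that receives the two positions, their indices, the squared distance and the current output, and returns the updated output), here is a brute-force stand-in. It is a hypothetical sketch, not CellListMap’s API or implementation: it visits all N(N-1)/2 pairs and uses plain, non-periodic Euclidean distance. A cell-list implementation changes which pairs are visited, not the shape of the callback.

```julia
# Hypothetical brute-force version of "apply f to every pair closer than cutoff".
# f(x, y, i, j, d2, output) must return the (possibly updated) output.
function map_pairwise_bruteforce(f, output, x::AbstractVector, cutoff)
    for i in 1:length(x)-1, j in i+1:length(x)
        d2 = sum(abs2, x[i] - x[j])          # squared Euclidean distance
        if d2 < cutoff^2
            output = f(x[i], x[j], i, j, d2, output)
        end
    end
    return output
end

# Example callback: histogram of pair distances with bins of width 1
function hist_distances!(x, y, i, j, d2, hist)
    hist[floor(Int, sqrt(d2)) + 1] += 1
    return hist
end

x = [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.0, 2.5, 0.0], [9.0, 9.0, 9.0]]
hist = map_pairwise_bruteforce(hist_distances!, zeros(Int, 3), x, 3.0)
```

Extra data such as velocities or bin edges can be passed by wrapping the callback in a closure, (x, y, i, j, d2, h) -> ..., which is what the adapted code in the next post does.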

Updated the name to CellListMap (https://github.com/m3g/CellListMap.jl)


Thank you so much! It looks really great. I found some time today to adapt my code to your package and compare it to my old implementation. It is already a factor of 3 faster. However, I think I might be doing something wrong, since the number of allocations I get is quite large compared to what you posted.

This is what I found for 100_000 objects:

Old code: 6.823 s (517132 allocations: 80.69 MiB)
New code (using CellListMap): 2.335 s (10762800 allocations: 380.29 MiB)

And this is how I adapted your package to my example:

function get_periodic_difference_cell_lists( x, y, period)
    delta = abs.(x - y)
    return @. ifelse(delta > 0.5 * period, delta - period, delta )* sign(x-y)
end

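function compute_pairwise_mean_cell_lists!(x,y,i,j,d2,hist,velocities, rbins,sides)
    d = get_periodic_difference_cell_lists(x,y,sides)
    r = LinearAlgebra.norm(d)
    # skip pairs outside the binned range: for r <= first(rbins), ibin would be 0
    # and hist[1][ibin] would throw a BoundsError
    first(rbins) < r < last(rbins) || return hist
    ibin = searchsortedfirst(rbins, r) - 1
    hist[1][ibin] += 1
    hist[2][ibin] += LinearAlgebra.dot(velocities[i] - velocities[j],d)/r
    return hist
end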

function get_pairwise_velocity_radial_mean_cell_lists(
        positions, velocities,
        rbins,
        boxsize
        )
    n = size(positions)[1]
    r_max = maximum(rbins)
    lc = LinkedLists(n)
    box = Box(boxsize, r_max)
    positions = [SVector{3, Float64}(positions[i,:]) for i in 1:n]
    velocities = [SVector{3, Float64}(velocities[i,:]) for i in 1:n]
    initlists!(positions,box,lc)
    hist = (zeros(Int,length(rbins)-1), zeros(Float64,length(rbins)-1))
    hist = map_pairwise!(
            (x,y,i,j,d2,hist) -> compute_pairwise_mean_cell_lists!(x,y,i,j,d2,hist,velocities, rbins, boxsize),
            hist, positions, box, lc,
    )
    n_pairs = hist[1]
    mean_v_r = hist[2]
    mean_v_r[n_pairs .> 0] = mean_v_r[n_pairs .> 0]./n_pairs[n_pairs .> 0]
    return mean_v_r
end

Note that d2 in the input parameters is already the squared distance between x and y, taking the periodic boundary conditions into account, so here you can simply use r = sqrt(d2) instead of computing the norm again (the displacement vector d is still needed for the dot product).

I guess the allocations are coming from delta = abs.(x - y), because x - y probably allocates a vector there. One way to get around that kind of problem is to use StaticArrays:

julia> f(x,y) = abs.(x - y)
f (generic function with 1 method)

julia> @btime f(a,b) setup=(a=rand(3);b=rand(3)) evals=1
  96.000 ns (2 allocations: 224 bytes)
3-element Vector{Float64}:
 0.06231306984947671
 0.2947003877865333
 0.21740617333782897

julia> using StaticArrays

julia> @btime f(a,b) setup=(a=@SVector(rand(3));b=@SVector(rand(3))) evals=1
  23.000 ns (0 allocations: 0 bytes)
3-element SVector{3, Float64} with indices SOneTo(3):
 0.18450470577908096
 0.1602097243490066
 0.26416313723532725


PS: If you can provide a running example that I can copy and paste, I could test things more precisely.


That’s one; the broadcast is another (abs.(...) has to allocate a vector to store the result). Yet another probably comes from the final broadcast via @. in the return line.

I’d write this function in a non-vectorized form:

function get_periodic_difference_cell(x, y, period)
    delta = abs(x - y)
    return ifelse(delta > 0.5 * period, delta - period, delta )* sign(x-y)
end

and broadcast get_periodic_difference_cell instead, e.g. get_periodic_difference_cell.(x, y, sides). That way there is only one allocation, for holding the result (and none at all if x, y and sides are SVectors).

This is a common pattern in Julia: write a kernel that works on the smallest possible piece of data and broadcast it across all indices.
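The kernel-plus-broadcast pattern can be checked directly. The sketch below assumes StaticArrays.jl is installed; periodic_difference and allocations are hypothetical helper names. It broadcasts get_periodic_difference_cell over either plain Vectors or SVectors and measures the heap allocations inside a function, after a warm-up call, so that compilation and global variables do not pollute the measurement.

```julia
using StaticArrays

# Scalar kernel from the post: works on a single coordinate
function get_periodic_difference_cell(x, y, period)
    delta = abs(x - y)
    return ifelse(delta > 0.5 * period, delta - period, delta) * sign(x - y)
end

# Broadcast the kernel over all coordinates: the dot call is fused into one loop
periodic_difference(x, y, sides) = get_periodic_difference_cell.(x, y, sides)

# Bytes allocated by one call, measured inside a function after a warm-up call
function allocations(x, y, sides)
    periodic_difference(x, y, sides)
    return @allocated periodic_difference(x, y, sides)
end

a, b, sides = [0.1, 0.2, 0.9], [0.9, 0.2, 0.1], [1.0, 1.0, 1.0]
sa, sb, ssides = SVector{3,Float64}(a), SVector{3,Float64}(b), SVector{3,Float64}(sides)
println(allocations(a, b, sides))      # nonzero: the result Vector is heap-allocated
println(allocations(sa, sb, ssides))   # 0: the result SVector lives on the stack
```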
