Hi everyone! I’m new to Julia, and so far I have found it very nice and neat. However, my Julia implementation is still slower than the Python one I was trying to beat. The goal is to compute the mean radial pairwise velocity of objects as a function of their separation, given their 3D positions and velocities. The objects are inside a periodic box, so all distances need to account for periodic boundary conditions. To avoid computing more pairwise distances than necessary, I use a BallTree (from NearestNeighbors.jl) to keep only the pairs within the maximum separation of interest.
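Written out, the quantity being computed is the following. The notation is introduced here for convenience: $L_k$ is the box length along axis $k$, $r_0 < r_1 < \dots < r_B$ are the bin edges in `rbins`, and $N_b$ is the number of pairs in bin $b$. The rounding form of the periodic difference matches the minimum-image rule in `get_periodic_difference` for coordinates inside the box.

```latex
\begin{aligned}
\Delta x_{ij,k} &= (x_{i,k} - x_{j,k}) - L_k\,\operatorname{round}\!\left(\frac{x_{i,k} - x_{j,k}}{L_k}\right), \qquad k = 1, 2, 3,\\
r_{ij} &= \lVert \Delta\mathbf{x}_{ij} \rVert, \qquad
v_{r,ij} = \frac{(\mathbf{v}_i - \mathbf{v}_j)\cdot\Delta\mathbf{x}_{ij}}{r_{ij}},\\
\langle v_r \rangle_b &= \frac{1}{N_b}\sum_{\substack{i<j\\ r_{b-1} < r_{ij} \le r_b}} v_{r,ij}
\end{aligned}
```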
Here are the functions I defined:
    using Distances
    using LinearAlgebra
    using NearestNeighbors

    export get_pairwise_velocity_distribution

    # Minimum-image separation x - y along one axis of a periodic box of length `period`
    function get_periodic_difference(x, y, period)
        delta = abs(x - y)
        return ifelse(delta > 0.5 * period, delta - period, delta) * sign(x - y)
    end

    # Fills `ret` with the periodic separation and `dv` with the velocity difference
    # of one pair; returns the separation r and the radial pairwise velocity v_r
    function get_pairwise_velocities!(
        ret, dv,
        pos_pair_left, pos_pair_right,
        vel_pair_left, vel_pair_right,
        boxsize,
    )
        for i in 1:length(ret)
            dv[i] = vel_pair_left[i] - vel_pair_right[i]
            ret[i] = get_periodic_difference(pos_pair_left[i], pos_pair_right[i], boxsize[i])
        end
        r = LinearAlgebra.norm(ret)
        v_r = LinearAlgebra.dot(dv, ret) / r
        #v_t = sqrt(LinearAlgebra.dot(dv, dv) - v_r*v_r)/sqrt(2.)
        return r, v_r
    end

    # positions, velocities: N×3 matrices (one object per row)
    # rbins: sorted bin edges; boxsize: box length along each axis
    function get_pairwise_velocity_distribution(
        positions, velocities,
        rbins, boxsize,
    )
        r_max = maximum(rbins)
        metric = PeriodicEuclidean(boxsize)
        transposed_positions = permutedims(positions)  # 3×N, one point per column
        balltree = BallTree(transposed_positions, metric)
        mean_v_r = zeros(Float64, length(rbins) - 1)
        n_pairs = zeros(Int32, length(rbins) - 1)
        # neighbour lists for all N objects at once (sortres = false)
        idxs = inrange(balltree, transposed_positions, r_max, false)
        ret = zeros(Float64, size(positions, 2))
        dv = zeros(Float64, size(positions, 2))
        for i in 1:size(positions, 1)
            for j in idxs[i]
                if j > i  # count each pair only once
                    pos_i = @view positions[i, :]
                    pos_j = @view positions[j, :]
                    vel_i = @view velocities[i, :]
                    vel_j = @view velocities[j, :]
                    r, v_r = get_pairwise_velocities!(ret, dv, pos_i, pos_j, vel_i, vel_j, boxsize)
                    if first(rbins) < r < last(rbins)
                        rbin = searchsortedfirst(rbins, r) - 1
                        mean_v_r[rbin] += v_r
                        n_pairs[rbin] += 1
                    end
                end
            end
        end
        mean_v_r[n_pairs .> 0] = mean_v_r[n_pairs .> 0] ./ n_pairs[n_pairs .> 0]
        return mean_v_r
    end
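To check any faster variant against the code above, a brute-force O(N²) reference is handy on small inputs. This is a minimal sketch, not part of the package: `pairwise_vr_bruteforce` and `min_image` are hypothetical names, and it assumes `positions` and `velocities` are 3×N matrices (one object per column, i.e. the `permutedims` of the N×3 layout above) with coordinates in `[0, boxsize[k])`. It uses the same sign convention (`v_i - v_j` projected on the periodic `x_i - x_j`) and the same bin rule (`searchsortedfirst(rbins, r) - 1`) as `get_pairwise_velocity_distribution`.

```julia
# Minimum-image separation along one axis of a periodic box of length `period`.
# For differences in (-period, period) this agrees with get_periodic_difference.
min_image(d, period) = d - period * round(d / period)

# Hypothetical O(N^2) reference. positions, velocities: 3×N (one object per column);
# rbins: sorted bin edges; boxsize: one period per axis.
function pairwise_vr_bruteforce(positions, velocities, rbins, boxsize)
    nbins = length(rbins) - 1
    sum_vr = zeros(Float64, nbins)
    n_pairs = zeros(Int, nbins)
    N = size(positions, 2)
    for i in 1:N-1, j in i+1:N
        # separation and velocity difference as tuples, so no heap allocation per pair
        dr = ntuple(k -> min_image(positions[k, i] - positions[k, j], boxsize[k]), 3)
        dv = ntuple(k -> velocities[k, i] - velocities[k, j], 3)
        r = sqrt(sum(abs2, dr))
        if first(rbins) < r < last(rbins)
            b = searchsortedfirst(rbins, r) - 1
            sum_vr[b] += sum(dr .* dv) / r
            n_pairs[b] += 1
        end
    end
    # mean per bin; empty bins stay at zero, as in the original function
    return [n > 0 ? s / n : 0.0 for (s, n) in zip(sum_vr, n_pairs)]
end
```

Calling it on `permutedims(positions)` and `permutedims(velocities)` for a few thousand objects should reproduce the output of `get_pairwise_velocity_distribution` up to floating-point rounding.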
You can find the package and the benchmarking scripts in my GitHub repo: https://github.com/florpi/PairVelocities
When the number of objects (here called halos) is small, the Julia code wins. However, when the number of objects is larger than 10_000, the Python code becomes much faster (for instance, for 100_000 objects Python takes 0.42 seconds and Julia 6 seconds, roughly 14 times slower). This is the output from @btime when I use 1_000_000 objects:
207.891 s (8775469 allocations: 7.19 GiB)
All of those allocations come from the BallTree calls.
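Since the allocations come from the neighbour query (`inrange` returns a neighbour list for every object before any binning happens), one thing worth trying is to accumulate pair by pair without materialising those lists. The sketch below swaps in a different technique, a periodic linked-cell list, and has not been benchmarked against either version. Assumptions, all mine: `pairwise_vr_celllist` and `min_image` are hypothetical names; `positions` and `velocities` are 3×N matrices (one object per column, so each object's coordinates are contiguous in Julia's column-major storage); coordinates lie in `[0, boxsize[k])`; and `last(rbins) <= boxsize[k] / 3` on every axis, so the 27 neighbouring cells are always distinct.

```julia
min_image(d, period) = d - period * round(d / period)

# Hypothetical linked-cell alternative to the BallTree query.
function pairwise_vr_celllist(positions, velocities, rbins, boxsize)
    r_max = last(rbins)
    N = size(positions, 2)
    # cells per axis; each cell is at least r_max wide
    nc = ntuple(k -> floor(Int, boxsize[k] / r_max), 3)
    all(>=(3), nc) || throw(ArgumentError("need last(rbins) <= boxsize[k]/3 on every axis"))
    # 1-based cell indices of object i (mod wraps a coordinate equal to boxsize)
    cellof(i) = ntuple(k -> mod(floor(Int, positions[k, i] * nc[k] / boxsize[k]), nc[k]) + 1, 3)
    # linked list: head[c] is the first object in cell c, nxt[i] the next one (0 = end)
    head = zeros(Int, nc)
    nxt = zeros(Int, N)
    for i in 1:N
        c = cellof(i)
        nxt[i] = head[c...]
        head[c...] = i
    end
    nbins = length(rbins) - 1
    sum_vr = zeros(Float64, nbins)
    n_pairs = zeros(Int, nbins)
    for i in 1:N
        ci = cellof(i)
        for ox in -1:1, oy in -1:1, oz in -1:1
            # periodic wrap of the neighbouring cell index
            cj = ntuple(k -> mod1(ci[k] + (ox, oy, oz)[k], nc[k]), 3)
            j = head[cj...]
            while j != 0
                if j > i  # count each unordered pair once
                    dx = min_image(positions[1, i] - positions[1, j], boxsize[1])
                    dy = min_image(positions[2, i] - positions[2, j], boxsize[2])
                    dz = min_image(positions[3, i] - positions[3, j], boxsize[3])
                    r = sqrt(dx^2 + dy^2 + dz^2)
                    if first(rbins) < r < r_max
                        dvx = velocities[1, i] - velocities[1, j]
                        dvy = velocities[2, i] - velocities[2, j]
                        dvz = velocities[3, i] - velocities[3, j]
                        b = searchsortedfirst(rbins, r) - 1
                        sum_vr[b] += (dx * dvx + dy * dvy + dz * dvz) / r
                        n_pairs[b] += 1
                    end
                end
                j = nxt[j]
            end
        end
    end
    return [n > 0 ? s / n : 0.0 for (s, n) in zip(sum_vr, n_pairs)]
end
```

Each object is compared only with objects in its own cell and the 26 surrounding ones, with `j > i` so every pair is counted once. The only arrays allocated are `head`, `nxt`, the bin accumulators and the result, so the allocation count should not grow with the number of pairs.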
Do you have any ideas on how I could improve it so that it beats the Python implementation? I would appreciate your help.