Hi everyone! I’m new to Julia and so far have found it very nice and neat. However, my Julia implementation is still slower than the Python one I was trying to beat. The goal is to compute the mean radial pairwise velocity of objects as a function of their separation, given their 3D velocities and positions. Note that these objects are inside a periodic box, so all distances need to account for periodic boundary conditions. To avoid computing more pairwise distances than necessary, I use a BallTree to keep only the pairs within the maximum allowed distance.
Here are the functions I defined:
```julia
using Distances
using LinearAlgebra
using NearestNeighbors

export get_pairwise_velocity_distribution

# Signed separation x - y along one axis, wrapped by the minimum-image convention.
function get_periodic_difference(x, y, period)
    delta = abs(x - y)
    return ifelse(delta > 0.5 * period, delta - period, delta) * sign(x - y)
end

# Fill `ret` with the periodic separation and `dv` with the velocity difference
# of one pair, then return the pair distance and the radial velocity.
function get_pairwise_velocities!(
    ret, dv,
    pos_pair_left, pos_pair_right,
    vel_pair_left, vel_pair_right,
    boxsize,
)
    for i in 1:length(ret)
        dv[i] = vel_pair_left[i] - vel_pair_right[i]
        ret[i] = get_periodic_difference(pos_pair_left[i], pos_pair_right[i], boxsize[i])
    end
    r = LinearAlgebra.norm(ret)
    v_r = LinearAlgebra.dot(dv, ret) / r
    # v_t = sqrt(LinearAlgebra.dot(dv, dv) - v_r * v_r) / sqrt(2.)
    return r, v_r
end

function get_pairwise_velocity_distribution(positions, velocities, rbins, boxsize)
    r_max = maximum(rbins)
    metric = PeriodicEuclidean(boxsize)
    # NearestNeighbors expects one point per column.
    transposed_positions = permutedims(positions)
    balltree = BallTree(transposed_positions, metric)
    mean_v_r = zeros(Float64, length(rbins) - 1)
    n_pairs = zeros(Int32, length(rbins) - 1)
    # Neighbour lists for every object within r_max.
    idxs = inrange(balltree, transposed_positions, r_max, false)
    # Scratch buffers, one entry per spatial dimension.
    ret = zeros(Float64, size(positions, 2))
    dv = zeros(Float64, size(positions, 2))
    for i in 1:size(positions, 1)
        for j in idxs[i]
            if j > i  # count each pair only once
                pos_i = @view positions[i, :]
                pos_j = @view positions[j, :]
                vel_i = @view velocities[i, :]
                vel_j = @view velocities[j, :]
                r, v_r = get_pairwise_velocities!(ret, dv, pos_i, pos_j, vel_i, vel_j, boxsize)
                if first(rbins) < r < last(rbins)
                    rbin = searchsortedfirst(rbins, r) - 1
                    mean_v_r[rbin] += v_r
                    n_pairs[rbin] += 1
                end
            end
        end
    end
    # Turn the per-bin sums into means, skipping empty bins.
    nonempty = n_pairs .> 0
    mean_v_r[nonempty] ./= n_pairs[nonempty]
    return mean_v_r
end
```
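As a quick sanity check of the periodic wrapping and the sign of the radial velocity, here is a minimal worked example for a single pair. The numbers are made up, and `wrapped_difference` is a hypothetical stand-in that repeats the logic of `get_periodic_difference` so the snippet runs on its own.

```julia
using LinearAlgebra

# Hypothetical stand-in with the same logic as get_periodic_difference.
function wrapped_difference(x, y, period)
    delta = abs(x - y)
    return ifelse(delta > 0.5 * period, delta - period, delta) * sign(x - y)
end

# Made-up pair in a box of side 10: the objects sit on opposite sides of the
# x boundary, so their true separation is 1, not 9.
boxsize = [10.0, 10.0, 10.0]
pos_i, pos_j = [9.5, 2.0, 5.0], [0.5, 2.0, 5.0]
vel_i, vel_j = [1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]

sep = wrapped_difference.(pos_i, pos_j, boxsize)  # [-1.0, 0.0, 0.0]
r = norm(sep)                                     # 1.0
v_r = dot(vel_i .- vel_j, sep) / r                # -2.0, the pair is approaching

println((r, v_r))
```

With this sign convention a negative `v_r` means the two objects are moving towards each other across the periodic boundary.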
You can find the package and the benchmarking scripts in my GitHub repo: https://github.com/florpi/PairVelocities
When the number of objects (here called halos) is small, the Julia code wins. However, when the number of objects is larger than 10_000, the Python code becomes faster by more than an order of magnitude (for instance, for 100_000 objects Python takes 0.42 seconds and Julia 6 seconds). This is the output from @btime when I use 1_000_000 objects:
207.891 s (8775469 allocations: 7.19 GiB)
All those allocations come from calling the balltree.
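To make that claim easier to check in isolation, here is a minimal sketch that measures the memory allocated by the batch `inrange` query alone and compares it with querying the tree one point at a time. The data and sizes are made up, and `count_pairs_per_point` is a hypothetical helper, not something from the package. The per-point loop only avoids holding every neighbour list in memory at once; whether it is actually faster is not established here.

```julia
using Distances
using NearestNeighbors

# Made-up data for illustration: 1_000 points in a periodic unit box.
boxsize = [1.0, 1.0, 1.0]
points = rand(3, 1_000)   # one point per column, as NearestNeighbors expects
r_max = 0.1
tree = BallTree(points, PeriodicEuclidean(boxsize))

# Batch query: returns one neighbour list per point, all alive at once.
batch_idxs = inrange(tree, points, r_max, false)   # first call also compiles
batch_bytes = @allocated inrange(tree, points, r_max, false)

# Hypothetical alternative: query one point at a time, so only a single
# neighbour list exists at any moment.
function count_pairs_per_point(tree, points, r_max)
    n = 0
    for i in axes(points, 2)
        for j in inrange(tree, points[:, i], r_max)
            n += (j > i)   # count each pair once, like the j > i check
        end
    end
    return n
end

n_batch = sum(count(>(i), batch_idxs[i]) for i in eachindex(batch_idxs))
n_loop = count_pairs_per_point(tree, points, r_max)

println("bytes allocated by the batch query: ", batch_bytes)
println("pairs found (batch, per point): ", (n_batch, n_loop))
```

Both approaches should find the same pairs, so the two counts are a cheap consistency check before timing anything with @btime.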
Do you have any idea how I could improve it so that it beats the Python implementation? I would appreciate your help.