On my machine (Ivy Bridge i7-3770, DDR3-1333 RAM), the Tullio version is 4x slower than @DNF’s with @avx. Maybe that’s a hardware issue.
Full results:
Function definitions
using Pkg; pkg"add LoopVectorization https://github.com/mcabbott/Tullio.jl"
using Base.Threads, LoopVectorization, Tullio
# Haversine distance matrix between all row pairs of `data1` and `data2`.
# Each input is an N×2 matrix holding (latitude, longitude) in degrees;
# returns a size(data1, 1) × size(data2, 1) matrix of distances in km
# (Earth radius taken as 6373 km).
function distances(data1, data2)
    data1 = deg2rad.(data1)
    data2 = deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    # FIX: the original used NumPy-style `lat1[:, None]` indexing; `None` is
    # undefined in Julia, so the function could not run. Reshape to a column
    # and a row vector so broadcasting yields the full pairwise matrix.
    diff_lat = reshape(lat1, :, 1) .- reshape(lat2, 1, :)
    diff_lng = reshape(lng1, :, 1) .- reshape(lng2, 1, :)
    # Precompute the cosine factors (column × row) so the `@.` fusion below
    # does not accidentally dot the `reshape` calls.
    clat1 = cos.(reshape(lat1, :, 1))
    clat2 = cos.(reshape(lat2, 1, :))
    data = @. sin(diff_lat / 2)^2 + clat1 * clat2 * sin(diff_lng / 2)^2
    # Haversine term -> great-circle distance; `abs` guards against tiny
    # negative values introduced by floating-point rounding.
    data .= @. 2.0 * 6373.0 * atan(sqrt(abs(data)), sqrt(abs(1.0 - data)))
    return data  # already size(data1, 1) × size(data2, 1)
end
# Multi-threaded haversine distance matrix (km) between rows of `data1` and
# `data2`. Inputs are N×2 matrices of (latitude, longitude) in degrees.
# Columns of the result are independent, so each is filled by a separate
# loop iteration under `@threads`.
function distances_threaded(data1, data2)
    n1, n2 = size(data1, 1), size(data2, 1)
    # Convert each coordinate column to radians up front.
    lat1 = [deg2rad(data1[k, 1]) for k in 1:n1]
    lng1 = [deg2rad(data1[k, 2]) for k in 1:n1]
    lat2 = [deg2rad(data2[k, 1]) for k in 1:n2]
    lng2 = [deg2rad(data2[k, 2]) for k in 1:n2]
    out = Matrix{Float64}(undef, n1, n2)
    @threads for j in eachindex(lat2)
        la, lo = lat2[j], lng2[j]
        # Fused broadcast of the haversine term into column j.
        out[:, j] .= @. sin((lat1 - la) / 2) ^ 2 + cos(lat1) * cos(la) * sin((lng1 - lo) / 2) ^ 2
    end
    # Second threaded pass: haversine term -> great-circle distance;
    # `abs` guards against tiny negative rounding artifacts.
    @threads for k in eachindex(out)
        out[k] = 2.0 * 6373.0 * atan(sqrt(abs(out[k])), sqrt(abs(1.0 - out[k])))
    end
    return out
end
# Threaded + SIMD haversine distance matrix (km); credit @baggepinnen.
# Same algorithm as `distances_threaded`, with LoopVectorization's `@avx`
# applied to the inner computations. Requires `using LoopVectorization`.
function distances_threaded_simd(data1, data2) # @baggepinnen
# Convert each (lat, lng) column to radians up front.
lat1 = [deg2rad(data1[i,1]) for i in 1:size(data1,1)]
lng1 = [deg2rad(data1[i,2]) for i in 1:size(data1,1)]
lat2 = [deg2rad(data2[i,1]) for i in 1:size(data2,1)]
lng2 = [deg2rad(data2[i,2]) for i in 1:size(data2,1)]
data = Matrix{Float64}(undef, length(lat1), length(lat2))
# One result column per point in data2; columns are independent, so the
# outer loop is safe to thread.
@threads for i in eachindex(lat2)
lat, lng = lat2[i], lng2[i]
# SIMD-vectorized fused broadcast of the haversine term into column i.
@avx data[:, i] .= @. sin((lat1 - lat) / 2) ^ 2 + cos(lat1) * cos(lat) * sin((lng1 - lng) / 2) ^ 2
end
@threads for i in eachindex(data)
# NOTE(review): `@avx` is applied to a single scalar assignment here, not a
# loop or broadcast — verify it actually has a vectorization effect.
@avx data[i] = 2.0 * 6373.0 * atan(sqrt(abs(data[i])), sqrt(abs(1.0 - data[i])))
end
return data
end
# Broadcast-only haversine distance matrix (km); credit @DNF.
# Inputs are N×2 matrices of (latitude, longitude) in degrees; the result is
# size(data1, 1) × size(data2, 1).
function distances_bcast(data1, data2)
    d1 = deg2rad.(data1)
    d2 = deg2rad.(data2)
    lat1, lng1 = @view(d1[:, 1]), @view(d1[:, 2])
    lat2, lng2 = @view(d2[:, 1]), @view(d2[:, 2])
    # Lazy transposes (lat2', lng2') turn the broadcasts into a full pairwise
    # matrix in a single fused pass.
    out = sin.((lat1 .- lat2') ./ 2) .^ 2 .+
          cos.(lat1) .* cos.(lat2') .* sin.((lng1 .- lng2') ./ 2) .^ 2
    # In-place conversion of the haversine term to great-circle distance;
    # `abs` absorbs tiny negative floating-point rounding artifacts.
    @. out = 2 * 6373 * atan(sqrt(abs(out)), sqrt(abs(1 - out)))
    return out
end
# Broadcast haversine variant with LoopVectorization's `@avx` applied to the
# matrix-building broadcast. Inputs: N×2 matrices of degrees; output: km.
# Requires `using LoopVectorization`.
function distances_bcast_simd(data1, data2)
data1 = deg2rad.(data1)
data2 = deg2rad.(data2)
lat1 = @view data1[:, 1]
lng1 = @view data1[:, 2]
lat2 = @view data2[:, 1]
lng2 = @view data2[:, 2]
# SIMD-vectorized pairwise haversine term (lat2'/lng2' are lazy transposes).
@avx data = sin.((lat1 .- lat2')./2).^2 .+ cos.(lat1) .* cos.(lat2') .* sin.((lng1 .- lng2')./2).^2
# NOTE(review): this second pass is NOT under `@avx`, unlike the pass above —
# confirm whether that is intentional for a "_simd" variant.
@. data = 2 * 6373 * atan(sqrt(abs(data)), sqrt(abs(1 - data)))
return data
end
# Tullio-based haversine distance matrix (km). In `@tullio`, `:=` allocates a
# new output array and `=` writes in place; `* (2pi/360)` is the deg2rad
# conversion. Requires `using Tullio`.
function distances_tullio(data1deg, data2deg)
@tullio data1[n,c] := data1deg[n,c] * (2pi/360)
@tullio data2[n,c] := data2deg[n,c] * (2pi/360)
# Pairwise haversine term for every (n, m) pair of points.
@tullio data[n,m] := sin((data1[n,1] - data2[m,1])/2)^2 +
cos(data1[n,1]) * cos(data2[m,1]) * sin((data1[n,2] - data2[m,2])/2)^2
# In-place conversion to great-circle distance; this last expression is the
# function's return value — presumably `@tullio` returns the mutated `data`
# array (no explicit `return`) — TODO confirm against Tullio docs.
@tullio data[n,m] = 2 * 6373 * atan(sqrt(abs(data[n,m])), sqrt(abs(1 - data[n,m])))
end
using PyCall
# Reference NumPy and JAX implementations, defined in the Python interpreter
# via PyCall. The triple-quoted source below is passed verbatim to Python:
# `distances_np` is eager NumPy, `distances_jax` is JIT-compiled with jax.jit.
# `a`/`b` are 5000×2 uniform samples in [-100, 100]; `dist_np_test` /
# `dist_jax_test` wrap the calls so benchmarks measure eager evaluation
# (np.asarray forces JAX's lazy result to materialize).
# NOTE(review): the string content must stay unchanged — it is executed code,
# not a comment.
py"""
import typing as tp
from jax import numpy as jnp
import jax
import numpy as np
import time
@jax.jit
def distances_jax(data1, data2):
    # data1, data2 are the data arrays with 2 cols and they hold
    # lat., lng. values in those cols respectively
    np = jnp
    data1 = np.deg2rad(data1)
    data2 = np.deg2rad(data2)
    lat1 = data1[:, 0]
    lng1 = data1[:, 1]
    lat2 = data2[:, 0]
    lng2 = data2[:, 1]
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    data = 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
    return data.reshape(data1.shape[0], data2.shape[0])
def distances_np(data1, data2):
    # data1, data2 are the data arrays with 2 cols and they hold
    # lat., lng. values in those cols respectively
    data1 = np.deg2rad(data1)
    data2 = np.deg2rad(data2)
    lat1 = data1[:, 0]
    lng1 = data1[:, 1]
    lat2 = data2[:, 0]
    lng2 = data2[:, 1]
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    data = 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
    return data.reshape(data1.shape[0], data2.shape[0])
a = np.random.uniform(-100, 100, size=(5000, 2))
b = np.random.uniform(-100, 100, size=(5000, 2))
def dist_np_test():
    return np.array(distances_np(a, b))
# enforce eager evaluation
def dist_jax_test():
    return np.asarray(distances_jax(a, b))
"""
julia> a = [(rand()-0.5) * 200 for i in 1:5000, j in 1:2]; b = [(rand()-0.5) * 200 for i in 1:5000, j in 1:2];
julia> @btime distances($a, $b);
1.896 s (40 allocations: 572.36 MiB) # Float64
1.525 s (40 allocations: 286.18 MiB) # Float32
julia> @btime distances_bcast($a, $b);
1.849 s (30 allocations: 190.89 MiB) # Float64
1.452 s (30 allocations: 95.44 MiB) # Float32
julia> @btime distances_bcast_simd($a, $b);
1.202 s (43 allocations: 190.89 MiB) # Float64
743.608 ms (43 allocations: 95.44 MiB) # Float32
julia> @btime distances_threaded($a, $b); # 4 threads
476.385 ms (69 allocations: 190.89 MiB) # Float64
378.335 ms (69 allocations: 95.45 MiB) # Float32
julia> @btime distances_threaded_simd($a, $b);
337.799 ms (69 allocations: 190.89 MiB) # Float64
208.833 ms (66 allocations: 95.45 MiB) # Float32
julia> @btime distances_tullio($a, $b);
4.643 s (645 allocations: 190.91 MiB) # Float64; Probably incompatibility with Ivy Bridge arch
551.643 ms (648 allocations: 190.91 MiB) # Float64, w/o LoopVectorization
julia> @btime py"dist_np_test"();
2.444 s (39 allocations: 190.74 MiB) # Float64
julia> @btime py"dist_jax_test"(); # result is in single precision
268.703 ms (8 allocations: 320 bytes) # Float32
In conclusion:
- @DNF’s broadcasted version is more readable than the OP’s and also slightly faster
- multi-threading gives an almost perfect acceleration, as expected from the lack of any data dependencies
- SIMD gives an additional ~1.5x boost on Float64 and close to 2x on Float32
- single-threaded Julia solution with SIMD is ~2x faster than numpy (also single-threaded, the result may be very BLAS-dependent, my numpy uses the default OpenBLAS from OpenSUSE Tumbleweed)
- JAX in single precision is on par with multi-threaded Julia in double precision, Julia in single precision beats JAX