Improving an algorithm that computes GPS distances

On my machine (Ivy Bridge i7-3770, DDR3-1333 RAM), the Tullio version is 4x slower than @DNF’s with @avx. Maybe that’s a hardware issue.
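
For context, every variant below evaluates the same pairwise haversine (great-circle) distance: with latitudes φ and longitudes λ in radians, d = sin²(Δφ/2) + cos φ₁ · cos φ₂ · sin²(Δλ/2) and distance = 2 · 6373 km · atan2(√d, √(1 − d)); the abs calls in the code presumably only guard against tiny negative values from rounding.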

Full results:

Function definitions
using Pkg; pkg"add LoopVectorization https://github.com/mcabbott/Tullio.jl"
using Base.Threads, LoopVectorization, Tullio

# numpy-style `None`/newaxis indexing used in `distances` below; the original post
# doesn't show this definition, but indexing with `[CartesianIndex()]` adds a singleton dimension
const None = [CartesianIndex()]

function distances(data1, data2)
    data1 = deg2rad.(data1)
    data2 = deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    diff_lat = @view(lat1[:, None]) .- @view(lat2[None, :])
    diff_lng = @view(lng1[:, None]) .- @view(lng2[None, :])
    data = (
        @. sin(diff_lat / 2) ^ 2 +
        cos(@view(lat1[:, None])) * cos(@view(lat2[None,:])) * sin(diff_lng / 2) ^ 2
    )
    data .= @. 2.0 * 6373.0 * atan(sqrt(abs(data)), sqrt(abs(1.0 - data)))
    return reshape(data, (size(data1, 1), size(data2, 1)))
end

function distances_threaded(data1, data2)
    lat1 = [deg2rad(data1[i,1]) for i in 1:size(data1,1)]
    lng1 = [deg2rad(data1[i,2]) for i in 1:size(data1,1)]
    lat2 = [deg2rad(data2[i,1]) for i in 1:size(data2,1)]
    lng2 = [deg2rad(data2[i,2]) for i in 1:size(data2,1)]
    data = Matrix{Float64}(undef, length(lat1), length(lat2))
    @threads for i in eachindex(lat2)
        lat, lng = lat2[i], lng2[i]
        data[:, i] .= @. sin((lat1 - lat) / 2) ^ 2 + cos(lat1) * cos(lat) * sin((lng1 - lng) / 2) ^ 2
    end
    @threads for i in eachindex(data)
        data[i] = 2.0 * 6373.0 * atan(sqrt(abs(data[i])), sqrt(abs(1.0 - data[i])))
    end
    return data
end

function distances_threaded_simd(data1, data2) # @baggepinnen
    lat1 = [deg2rad(data1[i,1]) for i in 1:size(data1,1)]
    lng1 = [deg2rad(data1[i,2]) for i in 1:size(data1,1)]
    lat2 = [deg2rad(data2[i,1]) for i in 1:size(data2,1)]
    lng2 = [deg2rad(data2[i,2]) for i in 1:size(data2,1)]
    data = Matrix{Float64}(undef, length(lat1), length(lat2))
    @threads for i in eachindex(lat2)
        lat, lng = lat2[i], lng2[i]
        @avx data[:, i] .= @. sin((lat1 - lat) / 2) ^ 2 + cos(lat1) * cos(lat) * sin((lng1 - lng) / 2) ^ 2
    end
    @threads for i in eachindex(data)
        @avx data[i] = 2.0 * 6373.0 * atan(sqrt(abs(data[i])), sqrt(abs(1.0 - data[i])))
    end
    return data
end

function distances_bcast(data1, data2) # @DNF
    data1 = deg2rad.(data1)
    data2 = deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    data = sin.((lat1 .- lat2')./2).^2 .+ cos.(lat1) .* cos.(lat2') .* sin.((lng1 .- lng2')./2).^2
    @. data = 2 * 6373 * atan(sqrt(abs(data)), sqrt(abs(1 - data)))
    return data
end

function distances_bcast_simd(data1, data2)
    data1 = deg2rad.(data1)
    data2 = deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    @avx data = sin.((lat1 .- lat2')./2).^2 .+ cos.(lat1) .* cos.(lat2') .* sin.((lng1 .- lng2')./2).^2
    @. data = 2 * 6373 * atan(sqrt(abs(data)), sqrt(abs(1 - data)))
    return data
end

function distances_tullio(data1deg, data2deg)
    @tullio data1[n,c] := data1deg[n,c] * (2pi/360)
    @tullio data2[n,c] := data2deg[n,c] * (2pi/360)

    @tullio data[n,m] := sin((data1[n,1] - data2[m,1])/2)^2 + 
        cos(data1[n,1]) * cos(data2[m,1]) * sin((data1[n,2] - data2[m,2])/2)^2

    @tullio data[n,m] = 2 * 6373 * atan(sqrt(abs(data[n,m])), sqrt(abs(1 - data[n,m])))
end
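
Before timing anything, a quick sanity check that all the Julia variants agree on a small input (a minimal sketch; the input size and use of `@assert` are arbitrary):

let a = [(rand()-0.5) * 200 for i in 1:100, j in 1:2],
    b = [(rand()-0.5) * 200 for i in 1:100, j in 1:2]

    ref = distances_bcast(a, b)
    @assert distances(a, b) ≈ ref
    @assert distances_threaded(a, b) ≈ ref
    @assert distances_threaded_simd(a, b) ≈ ref
    @assert distances_tullio(a, b) ≈ ref
end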

using PyCall
py"""
import typing as tp
from jax import numpy as jnp
import jax
import numpy as np
import time
@jax.jit
def distances_jax(data1, data2):
    # data1, data2 are the data arrays with 2 cols and they hold
    # lat., lng. values in those cols respectively
    np = jnp
    data1 = np.deg2rad(data1)
    data2 = np.deg2rad(data2)
    lat1 = data1[:, 0]
    lng1 = data1[:, 1]
    lat2 = data2[:, 0]
    lng2 = data2[:, 1]
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    data = 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
    return data.reshape(data1.shape[0], data2.shape[0])
def distances_np(data1, data2):
    # data1, data2 are the data arrays with 2 cols and they hold
    # lat., lng. values in those cols respectively
    data1 = np.deg2rad(data1)
    data2 = np.deg2rad(data2)
    lat1 = data1[:, 0]
    lng1 = data1[:, 1]
    lat2 = data2[:, 0]
    lng2 = data2[:, 1]
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    data = 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
    return data.reshape(data1.shape[0], data2.shape[0])

a = np.random.uniform(-100, 100, size=(5000, 2))
b = np.random.uniform(-100, 100, size=(5000, 2))

def dist_np_test():
    return np.array(distances_np(a, b))

# enforce eager evaluation
def dist_jax_test():
    return np.asarray(distances_jax(a, b))

"""
julia> using BenchmarkTools

julia> a = [(rand()-0.5) * 200 for i in 1:5000, j in 1:2]; b = [(rand()-0.5) * 200 for i in 1:5000, j in 1:2];

julia> @btime distances($a, $b);
1.896 s (40 allocations: 572.36 MiB) # Float64
1.525 s (40 allocations: 286.18 MiB) # Float32

julia> @btime distances_bcast($a, $b);
1.849 s (30 allocations: 190.89 MiB) # Float64
1.452 s (30 allocations: 95.44 MiB) # Float32

julia> @btime distances_bcast_simd($a, $b);
1.202 s (43 allocations: 190.89 MiB) # Float64
743.608 ms (43 allocations: 95.44 MiB) # Float32

julia> @btime distances_threaded($a, $b); # 4 threads
476.385 ms (69 allocations: 190.89 MiB) # Float64
378.335 ms (69 allocations: 95.45 MiB) # Float32

julia> @btime distances_threaded_simd($a, $b);
337.799 ms (69 allocations: 190.89 MiB) # Float64
208.833 ms (66 allocations: 95.45 MiB) # Float32

julia> @btime distances_tullio($a, $b);
4.643 s (645 allocations: 190.91 MiB) # Float64; probably an incompatibility with the Ivy Bridge arch
551.643 ms (648 allocations: 190.91 MiB) # Float64, w/o LoopVectorization

julia> @btime py"dist_np_test"();
2.444 s (39 allocations: 190.74 MiB) # Float64

julia> @btime py"dist_jax_test"(); # result is in single precision
268.703 ms (8 allocations: 320 bytes) # Float32
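
The Float32 rows were presumably obtained by converting the inputs to single precision first; the exact calls aren't shown above, but for the type-generic versions it amounts to something like:

julia> a32 = Float32.(a); b32 = Float32.(b);

julia> @btime distances_bcast($a32, $b32);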

In conclusion:

  • @DNF’s broadcasted version is more readable than the OP’s and also slightly faster
  • multi-threading gives a near-perfect speedup, as expected given the absence of data dependencies
  • SIMD gives an additional ~1.5x boost on Float64 and close to 2x on Float32
  • the single-threaded Julia solution with SIMD is ~2x faster than NumPy (also single-threaded; the result may be very BLAS-dependent, as my NumPy uses the default OpenBLAS from openSUSE Tumbleweed)
  • JAX in single precision is on par with multi-threaded Julia in double precision; Julia in single precision beats JAX