Improving an algorithm that computes GPS distances

Hi everyone! Is it possible to parallelize vectorized functions/operations? I have this code that I ported from Python, but it runs a little slower than NumPy and, not surprisingly, 10x slower than JAX (CPU only).

using Distributions
using BenchmarkTools
const None = [CartesianIndex()] 
function distances(data1, data2)
    data1 = deg2rad.(data1)
    data2 = deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    diff_lat = @view(lat1[:, None]) .- @view(lat2[None, :])
    diff_lng = @view(lng1[:, None]) .- @view(lng2[None, :])
    data = (
        @. sin(diff_lat / 2) ^ 2 +
        cos(@view(lat1[:, None])) * cos(lat2) * sin(diff_lng / 2) ^ 2
    )
    data .= @. 2.0 * 6373.0 * atan(sqrt(abs(data)), sqrt(abs(1.0 - data)))
    return reshape(data, (size(data1, 1), size(data2, 1)))
end

The Python code:

# test.py
import typing as tp
from jax import numpy as jnp
import jax
import numpy as np
import time
@jax.jit
def distances_jax(data1, data2):
    # data1, data2 are the data arrays with 2 cols and they hold
    # lat., lng. values in those cols respectively
    np = jnp
    data1 = np.deg2rad(data1)
    data2 = np.deg2rad(data2)
    lat1 = data1[:, 0]
    lng1 = data1[:, 1]
    lat2 = data2[:, 0]
    lng2 = data2[:, 1]
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    data = 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
    return data.reshape(data1.shape[0], data2.shape[0])
def distances_np(data1, data2):
    # data1, data2 are the data arrays with 2 cols and they hold
    # lat., lng. values in those cols respectively
    data1 = np.deg2rad(data1)
    data2 = np.deg2rad(data2)
    lat1 = data1[:, 0]
    lng1 = data1[:, 1]
    lat2 = data2[:, 0]
    lng2 = data2[:, 1]
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    data = 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
    return data.reshape(data1.shape[0], data2.shape[0])
a = np.random.uniform(-100, 100, size=(5000, 2))
b = np.random.uniform(-100, 100, size=(5000, 2))
t0 = time.time()
d = distances_np(a, b)
print("time np", time.time() - t0)
jnp.array([1])
t0 = time.time()
d = distances_jax(a, b)
print("time jax", time.time() - t0)

Any help?

Hello and welcome!

Is the procedure supposed to compute all pairwise distances between points?
If so, then there’s a bug in the Julia code: the matrix returned by distances(a, a) is not symmetric. The second cosine probably needs cos(@view(lat2[None, :])) when computing d (called data in the Julia version).
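
The symmetry check mentioned above can be sketched in plain Python (standard library only). This is an illustration, not the poster's code: the sample points are hypothetical, and the `buggy` flag only emulates the wrong second cosine for the case where both inputs are the same set of points.

```python
import math

EARTH_RADIUS_KM = 6373.0  # same radius as in the thread's code


def pairwise(points1, points2, buggy=False):
    """Haversine distances in km between (lat, lng) pairs given in degrees."""
    rows = []
    for lat1, lng1 in points1:
        p1, l1 = math.radians(lat1), math.radians(lng1)
        row = []
        for lat2, lng2 in points2:
            p2, l2 = math.radians(lat2), math.radians(lng2)
            # buggy=True reuses the row latitude in the second cosine
            second = p1 if buggy else p2
            d = (math.sin((p1 - p2) / 2) ** 2
                 + math.cos(p1) * math.cos(second) * math.sin((l1 - l2) / 2) ** 2)
            row.append(2 * EARTH_RADIUS_KM
                       * math.atan2(math.sqrt(abs(d)), math.sqrt(abs(1 - d))))
        rows.append(row)
    return rows


def is_symmetric(m, tol=1e-9):
    """True if the square matrix m equals its transpose within tol."""
    return all(abs(m[i][j] - m[j][i]) <= tol
               for i in range(len(m)) for j in range(len(m)))


# Hypothetical sample points (lat, lng in degrees), chosen only for illustration
points = [(48.85, 2.35), (40.71, -74.01), (-33.87, 151.21)]
print(is_symmetric(pairwise(points, points)))              # correct formula
print(is_symmetric(pairwise(points, points, buggy=True)))  # wrong second cosine
```

A symmetric result for identical inputs does not prove the implementation correct, but an asymmetric one is a cheap signal that one of the terms is indexed along the wrong axis.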

Regardless of that, you have to be a bit more explicit to use multi-threaded computations in Julia:

using Base.Threads
function distances_threaded(data1, data2)
    data1 = deg2rad.(data1)
    data2 = deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    data = Matrix{Float64}(undef, length(lat1), length(lat2))
    @threads for i in eachindex(lat2)
        lat2i, lng2i = lat2[i], lng2[i]
        data[:, i] .= @. sin((lat1 - lat2i) / 2) ^ 2 + cos(lat1) * cos(lat2i) * sin((lng1 - lng2i) / 2) ^ 2
    end
    @threads for i in eachindex(data)
        data[i] = 2.0 * 6373.0 * atan(sqrt(abs(data[i])), sqrt(abs(1.0 - data[i])))
    end
    return data
end

(Also, you have to start Julia with JULIA_NUM_THREADS=⟨N⟩ explicitly to get any acceleration from @threads.)

Using LoopVectorization shaves off quite some time here:

using LoopVectorization # provides @avx

function distances_avx(data1, data2)
    data1 = deg2rad.(data1)
    data2 = deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    data = Matrix{Float64}(undef, length(lat1), length(lat2))
    for i in eachindex(lat2)
        lat2i, lng2i = lat2[i], lng2[i]
        @avx data[:, i] .= @. sin((lat1 - lat2i) / 2) ^ 2 + cos(lat1) * cos(lat2i) * sin((lng1 - lng2i) / 2) ^ 2
    end
    @avx for i in eachindex(data)
        data[i] = 2.0 * 6373.0 * atan(sqrt(abs(data[i])), sqrt(abs(1.0 - data[i])))
    end
    return data
end

a = rand(Uniform(-100, 100), 5000, 2)
b = rand(Uniform(-100, 100), 5000, 2)

@time distances(a,b);     # 2.215078 seconds (5.01 k allocations: 191.117 MiB)
@time distances_avx(a,b); # 0.839410 seconds (6 allocations: 190.888 MiB)

using Test
@test distances_avx(a,b) ≈ distances(a,b) # passes

The function still allocates a lot, so reducing those allocations would likely improve performance considerably.

EDIT: Making it type stable and passing it Float32 reduces the avx version timing to

227.355 ms (6 allocations: 95.44 MiB)

without threading
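
The allocation figures quoted in this thread can be roughly accounted for with some arithmetic. The sketch below rests on my own assumption, not on anything the posters state: that the memory is dominated by the 5000×5000 result matrix plus the two small deg2rad copies of the inputs.

```python
MIB = 1024 ** 2  # bytes per MiB


def matrix_mib(rows, cols, bytes_per_element):
    """Size of a dense rows x cols matrix in MiB."""
    return rows * cols * bytes_per_element / MIB


n = 5000
result_f64 = matrix_mib(n, n, 8)      # one Float64 result matrix
result_f32 = matrix_mib(n, n, 4)      # the same matrix in Float32
inputs_f64 = 2 * matrix_mib(n, 2, 8)  # two deg2rad copies of the 5000x2 inputs
inputs_f32 = 2 * matrix_mib(n, 2, 4)

print(f"Float64, one matrix plus inputs:    {result_f64 + inputs_f64:.2f} MiB")
print(f"Float32, one matrix plus inputs:    {result_f32 + inputs_f32:.2f} MiB")
print(f"Float64, three matrices plus inputs: {3 * result_f64 + inputs_f64:.2f} MiB")
```

Under that assumption, a version that allocates only the result should sit near the first figure, and one that also materialises the two difference matrices should sit near the third, which is why avoiding temporaries matters as much as the choice of element type.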

This post claims that JAX is ~300x faster than Julia.

I am a little skeptical, but the code is here.

I wonder if the author has exhausted Julia optimizations?

Possibly the opposite; the code has globals (None = ...).

Julia’s speed features seem to elude newcomers.

Especially those who fail to read the manual.

It never ceases to amaze me how people expect to use a complex tool without some initial investment. But apparently it is a common phenomenon: the other day I was reading an interview with a trauma surgeon on the rise of injuries from people trying to do some home improvement with power tools during social distancing, with predictable results.

Read the manual or lose your fingers. The same applies to Julia, metaphorically.

Well, but it is not so difficult.

First, all the source code (not only the function being benchmarked) should be inside functions, to avoid unneeded global variables. That is simple.

The version without threads is not good code (global None, no preallocated memory, …). The threaded version is a copy-paste of the one proposed by @Vasily_Pisarev, but in my opinion the non-threaded version should be that same code without the threads, rather than the old version with its performance problems.

Also, the example is fully vectorised. In real-world problems it is often difficult to write all the code that way, and it is in those cases that Julia shines, because there are no such restrictions. I knew that JAX could speed things up a lot while keeping the NumPy API, but I did not know about jax.jit. It is great, but I am not sure whether it would do as well on non-vectorised operations.

JAX (and Flax) is great, but it is not easy for newcomers either. It is a great tool, but I still prefer Julia.
The LoopVectorization results are great too, thanks @baggepinnen.

The only global in the OP code is a constant; aren’t those supposed to be free with respect to performance?

No, Julia does not know that it is a constant unless you say so explicitly:

const None = ...

Anyway, it is simple to put it inside the function because it is only used inside one function.

a and b are also globals, so when benchmarking you should interpolate them with $:

@btime distances_...($a, $b)

so that access to the globals does not distort the benchmark.
Anyway, the threaded version is better code.

It should be possible to reduce the computation time quite a bit, though:

julia> a = randn(Float32, 10_000);
julia> @btime $a .|> SLEEFPirates.sin_fast .|> SLEEFPirates.cos_fast .|> SLEEFPirates.atan .|> abs .|> sqrt .|> abs2;
  680.251 μs (2 allocations: 39.14 KiB)

So if memory access does not become a bottleneck, speeds similar to the quoted JAX speed should be feasible.

I often find that life’s too short not to read the manual. If I’m getting a new frying pan I might skip it (might still look at any tasty pictures), if I’m adopting a new software tool, I’d better RTFM.

All this None stuff just messes up the code. Just rely on broadcasting:

function distances_(data1, data2)
    data1 .= deg2rad.(data1)
    data2 .= deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    diff_lat = lat1 .- lat2'
    diff_lng = lng1 .- lng2'
    data = @. sin(diff_lat/2) ^ 2 + cos(lat1) * cos(lat2') * sin(diff_lng/2)^2  # lat2' so the second cosine varies along columns
    @. data = 2 * 6373 * atan(sqrt(abs(data)), sqrt(abs(1 - data)))
    return reshape(data, size(data1, 1), :)
end

Can someone confirm the jax benchmark? It’s three orders of magnitude faster than Julia apparently. Really?

That’s why I don’t believe it

Turns out the Jax code computes the result lazily, so the function returns immediately, and then the matrix is computed in the background. Any operation that reads from the matrix then simply waits for the computation to be done.

This is actually very impressive of Jax, but on my computer it takes 380 ms for Jax to compute the matrix. This is the real timing we should compare against.
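
The timing pitfall described above can be illustrated without JAX. The sketch below is only an analogy: a thread-pool future stands in for the array that JAX hands back before the work is finished, and `slow_square_sum` and `lazy_square_sum` are hypothetical names invented for the example. In JAX itself, as far as I know, the blocking step is a call to `.block_until_ready()` on the returned array.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_square_sum(n):
    time.sleep(0.2)  # stand-in for the real computation
    return sum(i * i for i in range(n))


executor = ThreadPoolExecutor(max_workers=1)


def lazy_square_sum(n):
    # Returns a handle right away; the work continues in the background
    return executor.submit(slow_square_sum, n)


# Naive timing: measures only how long it takes to hand off the work
t0 = time.perf_counter()
handle = lazy_square_sum(1000)
dispatch_time = time.perf_counter() - t0
handle.result()  # let the first run finish before timing again

# Honest timing: wait for the result inside the timed region
# (the JAX counterpart would be distances_jax(a, b).block_until_ready())
t0 = time.perf_counter()
value = lazy_square_sum(1000).result()
full_time = time.perf_counter() - t0

executor.shutdown()
print(f"dispatch only:  {dispatch_time * 1e3:.1f} ms")
print(f"with blocking:  {full_time * 1e3:.1f} ms")
```

The first number says almost nothing about the cost of the computation, which is how a three-orders-of-magnitude gap can appear in a benchmark that never reads the result.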

OK, that makes more sense! Some Julia times:

julia> a = -100 .+ 200 .* rand(5000, 2);
julia> b = -100 .+ 200 .* rand(5000, 2); # Float64

julia> @btime distances_1($a, $b);  # @DNF's function, but not mutating a, b
  1.361 s (14 allocations: 572.36 MiB)

julia> @btime distances_2($a, $b); # line-for-line conversion
  175.966 ms (1581 allocations: 572.44 MiB)

julia> @btime distances_3($a, $b); # avoiding allocating diff_lat etc
  70.865 ms (794 allocations: 190.93 MiB)

julia> a2, b2, r2 = copy(a), copy(b), copy(res); # caches

julia> @btime distances_4($a, $b, $a2, $b2, $r2); # in-place version
  50.386 ms (787 allocations: 40.55 KiB)

julia> @btime distances_avx($a,$b); # from @baggepinnen above
  391.021 ms (6 allocations: 190.89 MiB)

This is multi-threaded (as is Jax, I presume), and with some LoopVectorization magic. On my (2-core) laptop, distances_3 takes 180.106 ms instead.

Edit: with code, after tidying up a little… this relies on a package which is a bit WIP:

a = -100 .+ 200 .* rand(5000, 2);
b = -100 .+ 200 .* rand(5000, 2);
res = distances_1(a, b);

using Pkg; pkg"add LoopVectorization https://github.com/mcabbott/Tullio.jl"
using LoopVectorization, Tullio

function distances_3(data1deg, data2deg)
    @tullio data1[n,c] := data1deg[n,c] * (2pi/360)
    @tullio data2[n,c] := data2deg[n,c] * (2pi/360)

    @tullio data[n,m] := sin((data1[n,1] - data2[m,1])/2)^2 + 
        cos(data1[n,1]) * cos(data2[m,1]) * sin((data1[n,2] - data2[m,2])/2)^2

    @tullio data[n,m] = 2 * 6373 * atan(sqrt(abs(data[n,m])), sqrt(abs(1 - data[n,m])))
end

res ≈ distances_3(a, b) # true

@btime distances_1($a, $b); 
@btime distances_3($a, $b);

Could you please post the code (or a gist link) here for reference? The comments just aren’t enough to reproduce the results.

Done!

Would be curious to see a comparison to Jax on the same machine.

On my machine (Ivy Bridge i7-3770, DDR3-1333 RAM), the Tullio version is 4x slower than @DNF’s with @avx. Maybe that’s a hardware issue.

Full results:

Function definitions
using Pkg; pkg"add LoopVectorization https://github.com/mcabbott/Tullio.jl"
using Base.Threads, LoopVectorization, Tullio

function distances(data1, data2)
    data1 = deg2rad.(data1)
    data2 = deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    diff_lat = @view(lat1[:, None]) .- @view(lat2[None, :])
    diff_lng = @view(lng1[:, None]) .- @view(lng2[None, :])
    data = (
        @. sin(diff_lat / 2) ^ 2 +
        cos(@view(lat1[:, None])) * cos(@view(lat2[None,:])) * sin(diff_lng / 2) ^ 2
    )
    data .= @. 2.0 * 6373.0 * atan(sqrt(abs(data)), sqrt(abs(1.0 - data)))
    return reshape(data, (size(data1, 1), size(data2, 1)))
end

function distances_threaded(data1, data2)
    lat1 = [deg2rad(data1[i,1]) for i in 1:size(data1,1)]
    lng1 = [deg2rad(data1[i,2]) for i in 1:size(data1,1)]
    lat2 = [deg2rad(data2[i,1]) for i in 1:size(data2,1)]
    lng2 = [deg2rad(data2[i,2]) for i in 1:size(data2,1)]
    data = Matrix{Float64}(undef, length(lat1), length(lat2))
    @threads for i in eachindex(lat2)
        lat, lng = lat2[i], lng2[i]
        data[:, i] .= @. sin((lat1 - lat) / 2) ^ 2 + cos(lat1) * cos(lat) * sin((lng1 - lng) / 2) ^ 2
    end
    @threads for i in eachindex(data)
        data[i] = 2.0 * 6373.0 * atan(sqrt(abs(data[i])), sqrt(abs(1.0 - data[i])))
    end
    return data
end

function distances_threaded_simd(data1, data2) # @baggepinnen
    lat1 = [deg2rad(data1[i,1]) for i in 1:size(data1,1)]
    lng1 = [deg2rad(data1[i,2]) for i in 1:size(data1,1)]
    lat2 = [deg2rad(data2[i,1]) for i in 1:size(data2,1)]
    lng2 = [deg2rad(data2[i,2]) for i in 1:size(data2,1)]
    data = Matrix{Float64}(undef, length(lat1), length(lat2))
    @threads for i in eachindex(lat2)
        lat, lng = lat2[i], lng2[i]
        @avx data[:, i] .= @. sin((lat1 - lat) / 2) ^ 2 + cos(lat1) * cos(lat) * sin((lng1 - lng) / 2) ^ 2
    end
    @threads for i in eachindex(data)
        @avx data[i] = 2.0 * 6373.0 * atan(sqrt(abs(data[i])), sqrt(abs(1.0 - data[i])))
    end
    return data
end

function distances_bcast(data1, data2) # @DNF
    data1 = deg2rad.(data1)
    data2 = deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    data = sin.((lat1 .- lat2')./2).^2 .+ cos.(lat1) .* cos.(lat2') .* sin.((lng1 .- lng2')./2).^2
    @. data = 2 * 6373 * atan(sqrt(abs(data)), sqrt(abs(1 - data)))
    return data
end

function distances_bcast_simd(data1, data2)
    data1 = deg2rad.(data1)
    data2 = deg2rad.(data2)
    lat1 = @view data1[:, 1]
    lng1 = @view data1[:, 2]
    lat2 = @view data2[:, 1]
    lng2 = @view data2[:, 2]
    @avx data = sin.((lat1 .- lat2')./2).^2 .+ cos.(lat1) .* cos.(lat2') .* sin.((lng1 .- lng2')./2).^2
    @. data = 2 * 6373 * atan(sqrt(abs(data)), sqrt(abs(1 - data)))
    return data
end

function distances_tullio(data1deg, data2deg)
    @tullio data1[n,c] := data1deg[n,c] * (2pi/360)
    @tullio data2[n,c] := data2deg[n,c] * (2pi/360)

    @tullio data[n,m] := sin((data1[n,1] - data2[m,1])/2)^2 + 
        cos(data1[n,1]) * cos(data2[m,1]) * sin((data1[n,2] - data2[m,2])/2)^2

    @tullio data[n,m] = 2 * 6373 * atan(sqrt(abs(data[n,m])), sqrt(abs(1 - data[n,m])))
end

using PyCall
py"""
import typing as tp
from jax import numpy as jnp
import jax
import numpy as np
import time
@jax.jit
def distances_jax(data1, data2):
    # data1, data2 are the data arrays with 2 cols and they hold
    # lat., lng. values in those cols respectively
    np = jnp
    data1 = np.deg2rad(data1)
    data2 = np.deg2rad(data2)
    lat1 = data1[:, 0]
    lng1 = data1[:, 1]
    lat2 = data2[:, 0]
    lng2 = data2[:, 1]
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    data = 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
    return data.reshape(data1.shape[0], data2.shape[0])
def distances_np(data1, data2):
    # data1, data2 are the data arrays with 2 cols and they hold
    # lat., lng. values in those cols respectively
    data1 = np.deg2rad(data1)
    data2 = np.deg2rad(data2)
    lat1 = data1[:, 0]
    lng1 = data1[:, 1]
    lat2 = data2[:, 0]
    lng2 = data2[:, 1]
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    data = 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
    return data.reshape(data1.shape[0], data2.shape[0])

a = np.random.uniform(-100, 100, size=(5000, 2))
b = np.random.uniform(-100, 100, size=(5000, 2))

def dist_np_test():
    return np.array(distances_np(a, b))

# enforce eager evaluation
def dist_jax_test():
    return np.asarray(distances_jax(a, b))

"""
julia> a = [(rand()-0.5) * 200 for i in 1:5000, j in 1:2]; b = [(rand()-0.5) * 200 for i in 1:5000, j in 1:2];

julia> @btime distances($a, $b);
1.896 s (40 allocations: 572.36 MiB) # Float64
1.525 s (40 allocations: 286.18 MiB) # Float32

julia> @btime distances_bcast($a, $b);
1.849 s (30 allocations: 190.89 MiB) # Float64
1.452 s (30 allocations: 95.44 MiB) # Float32

julia> @btime distances_bcast_simd($a, $b);
1.202 s (43 allocations: 190.89 MiB) # Float64
743.608 ms (43 allocations: 95.44 MiB) # Float32

julia> @btime distances_threaded($a, $b); # 4 threads
476.385 ms (69 allocations: 190.89 MiB) # Float64
378.335 ms (69 allocations: 95.45 MiB) # Float32

julia> @btime distances_threaded_simd($a, $b);
337.799 ms (69 allocations: 190.89 MiB) # Float64
208.833 ms (66 allocations: 95.45 MiB) # Float32

julia> @btime distances_tullio($a, $b);
4.643 s (645 allocations: 190.91 MiB) # Float64; Probably incompatibility with Ivy Bridge arch
551.643 ms (648 allocations: 190.91 MiB) # Float64, w/o LoopVectorization

julia> @btime py"dist_np_test"();
2.444 s (39 allocations: 190.74 MiB) # Float64

julia> @btime py"dist_jax_test"(); # result is in single precision
268.703 ms (8 allocations: 320 bytes) # Float32

In conclusion:

  • @DNF’s broadcasted version is more readable than the OP’s and also slightly faster
  • multi-threading gives an almost perfect acceleration, as expected from the lack of any data dependencies
  • SIMD gives an additional ~1.5x boost on Float64 and close to 2x on Float32
  • the single-threaded Julia solution with SIMD is ~2x faster than NumPy (also single-threaded; the result may be very BLAS-dependent, and my NumPy uses the default OpenBLAS from openSUSE Tumbleweed)
  • JAX in single precision is on par with multi-threaded Julia in double precision; Julia in single precision beats JAX
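
The ratios behind these conclusions can be recomputed from the timings listed above. This is just arithmetic on the posted numbers, not a new measurement; the dictionary keys are labels I made up, and the 476.385 ms threaded run is taken to be the Float64 one because it allocates 190.89 MiB.

```python
# Timings in milliseconds, copied from the benchmark listing above
timings_ms = {
    "bcast_f64": 1849.0,
    "bcast_f32": 1452.0,
    "bcast_simd_f64": 1202.0,
    "bcast_simd_f32": 743.608,
    "threaded_f64": 476.385,  # assumed Float64 (190.89 MiB allocated)
    "threaded_simd_f64": 337.799,
    "threaded_simd_f32": 208.833,
    "numpy_f64": 2444.0,
    "jax_f32": 268.703,
}


def speedup(slow, fast):
    """How many times faster `fast` is than `slow`."""
    return timings_ms[slow] / timings_ms[fast]


threads_gain = speedup("bcast_f64", "threaded_f64")      # 4 threads
simd_gain_f64 = speedup("bcast_f64", "bcast_simd_f64")
simd_gain_f32 = speedup("bcast_f32", "bcast_simd_f32")
numpy_vs_simd = speedup("numpy_f64", "bcast_simd_f64")
jax_vs_julia_f32 = speedup("jax_f32", "threaded_simd_f32")

for name, value in [
    ("threads (4)", threads_gain),
    ("SIMD, Float64", simd_gain_f64),
    ("SIMD, Float32", simd_gain_f32),
    ("Julia SIMD vs NumPy", numpy_vs_simd),
    ("Julia Float32 vs JAX", jax_vs_julia_f32),
]:
    print(f"{name:22s} {value:.2f}x")
```

Keeping the numbers in one table like this makes it easy to redo the comparison when someone posts timings from a different machine.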