Hi everyone! Is it possible to parallelize vectorized functions / operations? I have this code that I ported from Python, but it is running a little bit slower than NumPy
and, not surprisingly, 10x slower than JAX
(CPU only).
using Distributions
using BenchmarkTools
# NumPy-style "newaxis" helper: used as an index (e.g. `v[:, None]`), this
# one-element vector holding an empty `CartesianIndex` adds a dimension of
# length 1 at that position.
const None = [CartesianIndex()]
# Haversine distance between one pair of points. Latitudes/longitudes are in
# radians; the latitude cosines are passed in precomputed so that the caller can
# evaluate them once per point instead of once per pair. 6373 is the sphere
# radius used for scaling (approximately Earth's radius in kilometres).
function _haversine_km(lat1, lng1, coslat1, lat2, lng2, coslat2)
    h = sin((lat1 - lat2) / 2)^2 + coslat1 * coslat2 * sin((lng1 - lng2) / 2)^2
    # abs() guards the square roots against tiny negative values from rounding.
    return 2 * 6373 * atan(sqrt(abs(h)), sqrt(abs(1 - h)))
end

"""
    distances(data1, data2)

Return the matrix of pairwise haversine distances between the points in `data1`
and the points in `data2`.

Column 1 of each input is read as latitude and column 2 as longitude, both in
degrees. Entry `[i, j]` of the result is the distance between row `i` of `data1`
and row `j` of `data2`, on a sphere of radius 6373 (kilometres for Earth). The
result has size `(size(data1, 1), size(data2, 1))`; the two inputs may have
different numbers of rows.
"""
function distances(data1, data2)
    # Points of `data1` stay column-shaped (length N) ...
    lat1 = deg2rad.(view(data1, :, 1))
    lng1 = deg2rad.(view(data1, :, 2))
    # ... while points of `data2` are made row-shaped (1×M). A plain vector would
    # broadcast along the first dimension and pair point i with point i only.
    lat2 = permutedims(deg2rad.(view(data2, :, 1)))
    lng2 = permutedims(deg2rad.(view(data2, :, 2)))

    # Hoisted out of the N×M kernel: N + M cosine evaluations instead of 2·N·M.
    coslat1 = cos.(lat1)
    coslat2 = cos.(lat2)

    # Single fused broadcast: one N×M allocation, no intermediate matrices.
    return _haversine_km.(lat1, lng1, coslat1, lat2, lng2, coslat2)
end
The Python code:
# test.py
import typing as tp
from jax import numpy as jnp
import jax
import numpy as np
import time
@jax.jit
def distances_jax(data1, data2):
    """Return pairwise haversine distances between two point sets (JAX, jitted).

    Args:
        data1: array with latitude in column 0 and longitude in column 1,
            in degrees, one point per row.
        data2: array laid out like ``data1``.

    Returns:
        Array of shape ``(data1.shape[0], data2.shape[0])`` whose ``[i, j]``
        entry is the distance between row ``i`` of ``data1`` and row ``j`` of
        ``data2`` on a sphere of radius 6373 (kilometres for Earth).
    """
    data1 = jnp.deg2rad(data1)
    data2 = jnp.deg2rad(data2)
    lat1, lng1 = data1[:, 0], data1[:, 1]
    lat2, lng2 = data2[:, 0], data2[:, 1]

    # (N, 1) against (M,) broadcasts to the full (N, M) grid of pairs.
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        jnp.sin(diff_lat / 2) ** 2
        + jnp.cos(lat1[:, None]) * jnp.cos(lat2) * jnp.sin(diff_lng / 2) ** 2
    )
    # abs() guards the square roots against tiny negative values from rounding.
    # The result already has shape (N, M), so no reshape is needed.
    return 2 * 6373 * jnp.arctan2(jnp.sqrt(jnp.abs(d)), jnp.sqrt(jnp.abs(1 - d)))
def distances_np(data1, data2):
    """Return pairwise haversine distances between two point sets (NumPy).

    Args:
        data1: array with latitude in column 0 and longitude in column 1,
            in degrees, one point per row.
        data2: array laid out like ``data1``.

    Returns:
        Array of shape ``(data1.shape[0], data2.shape[0])`` whose ``[i, j]``
        entry is the distance between row ``i`` of ``data1`` and row ``j`` of
        ``data2`` on a sphere of radius 6373 (kilometres for Earth).
    """
    data1 = np.deg2rad(data1)
    data2 = np.deg2rad(data2)
    lat1, lng1 = data1[:, 0], data1[:, 1]
    lat2, lng2 = data2[:, 0], data2[:, 1]

    # (N, 1) against (M,) broadcasts to the full (N, M) grid of pairs.
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    # abs() guards the square roots against tiny negative values from rounding.
    # The result already has shape (N, M), so no reshape is needed.
    return 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
# Benchmark both implementations on the same pair of random 5000-point inputs.
a = np.random.uniform(-100, 100, size=(5000, 2))
b = np.random.uniform(-100, 100, size=(5000, 2))

t0 = time.perf_counter()
d = distances_np(a, b)
print("time np", time.perf_counter() - t0)

# Warm-up: the first call of a jitted function traces and compiles it for this
# input shape/dtype (and initialises the backend). Do that outside the timed
# region so only the actual computation is measured.
distances_jax(a, b).block_until_ready()

t0 = time.perf_counter()
# JAX dispatches work asynchronously; wait for the result before stopping the clock.
d = distances_jax(a, b).block_until_ready()
print("time jax", time.perf_counter() - t0)
any help?