I am in the process of translating a Python function into Julia. I expected the Julia version to be faster, but to my surprise the Python function outperformed it every time. Using the Profile module, I discovered that my Julia function was spending >30% of its time computing medians along the first dimension of a 4-D array (passed into the function as one of its arguments).
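For context, here is roughly how the hotspot showed up (a sketch, with f standing in for my real function):

using Profile, Statistics
f(a) = sum(median(a, dims=1))     # stand-in for the real function
a = rand(Float32, 10, 10, 10, 10)
f(a)                              # warm up so compilation isn't profiled
@profile for _ in 1:1000; f(a); end
Profile.print(format=:flat, sortedby=:count)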
Benchmarking this bottleneck yielded some surprising (to me) results. Clearly the stock Statistics.median function with the dims keyword argument is not the way to go here. How might I go about speeding up this computation to match the NumPy performance, or better?
Annotated Example:
First I generate a random 4-D array in Python and compute the median along its last axis. Because NumPy is row-major and Julia is column-major, I reverse the dimension order with permutedims on the Julia side, so NumPy's axis=3 corresponds to Julia's dims=1; I then test that the values are indeed the same.
julia> using PyCall, Statistics, BenchmarkTools
julia> py"""
import numpy as np
a = np.random.rand(10,10,10,10).astype(np.float32)
median_py = np.median(a, axis=3)
"""
julia> a = permutedims(py"a", 4:-1:1);
julia> median_py = permutedims(py"median_py", 3:-1:1)[[CartesianIndex()],:,:,:];
julia> median_jl = median(a, dims=1);
julia> median_py == median_jl
true
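(As an aside, the [[CartesianIndex()],:,:,:] indexing is just one way to reinsert the singleton first dimension that median(a, dims=1) keeps; an equivalent spelling would be reshape(permutedims(py"median_py", 3:-1:1), 1, 10, 10, 10).)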
Next, I run those computations again under the @benchmark macro. The Julia version takes over 4x as long as NumPy (comparing minimum times).
julia> @benchmark py"np.median(a, axis=3)"
BenchmarkTools.Trial:
memory estimate: 5.64 KiB
allocs estimate: 33
--------------
minimum time: 310.933 μs (0.00% GC)
median time: 436.488 μs (0.00% GC)
mean time: 489.510 μs (0.00% GC)
maximum time: 23.749 ms (0.00% GC)
--------------
samples: 9896
evals/sample: 1
julia> @benchmark median(a, dims=1)
BenchmarkTools.Trial:
memory estimate: 319.61 KiB
allocs estimate: 11071
--------------
minimum time: 1.352 ms (0.00% GC)
median time: 1.418 ms (0.00% GC)
mean time: 1.679 ms (1.61% GC)
maximum time: 8.195 ms (67.38% GC)
--------------
samples: 2963
evals/sample: 1
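(Aside: BenchmarkTools recommends interpolating global variables into the benchmarked expression with $, i.e. @benchmark median($a, dims=1), so that global-variable lookup isn't part of the measurement; the timings above were taken without interpolation.)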
Next, I recreate a "dumb" version of the dims argument with a comprehension, testing that it yields the same result. This is almost twice as fast as the stock Julia method, but still significantly slower than NumPy.
julia> median_jl_dumb = begin
w,x,y,z = size(a)
[median(a[:,i,j,k]) for i in 1:x, j in 1:y, k in 1:z][[CartesianIndex()],:,:,:]
end;
julia> median_jl == median_jl_dumb
true
julia> @benchmark begin
w,x,y,z = size(a)
[median(a[:,i,j,k]) for i in 1:x, j in 1:y, k in 1:z][[CartesianIndex()],:,:,:]
end
BenchmarkTools.Trial:
memory estimate: 306.38 KiB
allocs estimate: 4047
--------------
minimum time: 722.054 μs (0.00% GC)
median time: 728.818 μs (0.00% GC)
mean time: 780.170 μs (1.53% GC)
maximum time: 3.396 ms (68.06% GC)
--------------
samples: 6377
evals/sample: 1
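The comprehension still allocates a fresh vector for every slice. A rough sketch of the kind of hand-rolled loop I imagine might help, reusing a single scratch buffer and calling Statistics.median! (which partially sorts its argument in place), would be:

using Statistics

function median_dim1(a::AbstractArray{T,4}) where T
    w, x, y, z = size(a)
    out = Array{float(T)}(undef, 1, x, y, z)  # median of integers promotes, hence float(T)
    buf = Vector{T}(undef, w)                 # one reusable scratch buffer
    @inbounds for k in 1:z, j in 1:y, i in 1:x
        copyto!(buf, view(a, :, i, j, k))     # copy the slice; median! will reorder buf
        out[1, i, j, k] = median!(buf)
    end
    return out
end

This should agree with median(a, dims=1), but I haven't benchmarked it rigorously.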
Finally, I make sure there is nothing "wrong" with the Julia Statistics median function in general by comparing it to NumPy without any axis slicing, and indeed it is comparable or faster in this case:
julia> @benchmark py"np.median(a)"
BenchmarkTools.Trial:
memory estimate: 176 bytes
allocs estimate: 5
--------------
minimum time: 181.190 μs (0.00% GC)
median time: 186.564 μs (0.00% GC)
mean time: 199.499 μs (0.00% GC)
maximum time: 1.331 ms (0.00% GC)
--------------
samples: 10000
evals/sample: 1
julia> @benchmark median(a)
BenchmarkTools.Trial:
memory estimate: 39.16 KiB
allocs estimate: 3
--------------
minimum time: 131.974 μs (0.00% GC)
median time: 152.517 μs (0.00% GC)
mean time: 166.348 μs (3.97% GC)
maximum time: 9.239 ms (98.08% GC)
--------------
samples: 10000
evals/sample: 1
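For completeness, I am aware of mapslices, which expresses the same reduction (median_ms = mapslices(median, a, dims=1)), but I would expect it to allocate per slice much like the comprehension above, so I have not pursued it further.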