 # How to speed up `Statistics.median` along a given dimension?

I am in the process of translating a Python function into Julia. I was expecting the latter to be faster than the former, but to my surprise the Python function outperformed my Julia version each time. Using the `Profile` module, I discovered that my Julia function was spending >30% of its time computing a median value along the first dimension in a 4-D array (passed into the function as one of the arguments).

Benchmarking this bottleneck yielded some surprising (to me) results. Obviously using the stock `Statistics.median` function with the `dims` parameter is not the way to go here. How might I go about speeding up this computation, getting it down to the Python performance, or better?

Annotated Example:

First I generate a random 4-D array, compute the median along the matching axis in Python and Julia, and test that the values are indeed the same.

``````
julia> using PyCall, Statistics, BenchmarkTools

julia> py"""
import numpy as np
a = np.random.rand(10,10,10,10).astype(np.float32)
median_py = np.median(a, axis=3)
"""

julia> a = permutedims(py"a", 4:-1:1);

julia> median_py = permutedims(py"median_py", 3:-1:1)[[CartesianIndex()],:,:,:];

julia> median_jl = median(a, dims=1);

julia> median_py == median_jl
true

``````

Next, I run those computations again under the `@benchmark` macro. The Julia version takes over 4x longer to compute.

``````
julia> @benchmark py"np.median(a, axis=3)"
BenchmarkTools.Trial:
memory estimate:  5.64 KiB
allocs estimate:  33
--------------
minimum time:     310.933 μs (0.00% GC)
median time:      436.488 μs (0.00% GC)
mean time:        489.510 μs (0.00% GC)
maximum time:     23.749 ms (0.00% GC)
--------------
samples:          9896
evals/sample:     1

julia> @benchmark median(a, dims=1)
BenchmarkTools.Trial:
memory estimate:  319.61 KiB
allocs estimate:  11071
--------------
minimum time:     1.352 ms (0.00% GC)
median time:      1.418 ms (0.00% GC)
mean time:        1.679 ms (1.61% GC)
maximum time:     8.195 ms (67.38% GC)
--------------
samples:          2963
evals/sample:     1
``````

Next, I recreate a “dumb” version of the `dims` argument with a comprehension, testing to make sure it yields the same result. This is almost twice as fast as the provided Julia method, but still significantly slower than Python.

``````
julia> median_jl_dumb = begin
w,x,y,z = size(a)
[median(a[:,i,j,k]) for i in 1:x, j in 1:y, k in 1:z][[CartesianIndex()],:,:,:]
end;

julia> median_jl == median_jl_dumb
true

julia> @benchmark begin
w,x,y,z = size(a)
[median(a[:,i,j,k]) for i in 1:x, j in 1:y, k in 1:z][[CartesianIndex()],:,:,:]
end
BenchmarkTools.Trial:
memory estimate:  306.38 KiB
allocs estimate:  4047
--------------
minimum time:     722.054 μs (0.00% GC)
median time:      728.818 μs (0.00% GC)
mean time:        780.170 μs (1.53% GC)
maximum time:     3.396 ms (68.06% GC)
--------------
samples:          6377
evals/sample:     1
``````
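A large part of the comprehension's cost is that each `a[:,i,j,k]` slice copies a column. A `@view`-based variant (just a sketch; `median` still makes an internal copy when it sorts) shaves off some of that overhead:

```julia
using Statistics

a = rand(Float32, 10, 10, 10, 10)
w, x, y, z = size(a)

# Same comprehension as above, but each column is a view rather than a copy.
median_views = [median(@view a[:, i, j, k]) for i in 1:x, j in 1:y, k in 1:z]

# Matches the built-in result (up to the dropped singleton dimension).
median_views == dropdims(median(a, dims=1), dims=1)  # true
```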

Finally, I make sure there is nothing “wrong” with the Julia Statistics `median` function in general by comparing it to numpy without any axis slicing, and indeed it is comparable or faster in this case.

``````
julia> @benchmark py"np.median(a)"
BenchmarkTools.Trial:
memory estimate:  176 bytes
allocs estimate:  5
--------------
minimum time:     181.190 μs (0.00% GC)
median time:      186.564 μs (0.00% GC)
mean time:        199.499 μs (0.00% GC)
maximum time:     1.331 ms (0.00% GC)
--------------
samples:          10000
evals/sample:     1

julia> @benchmark median(a)
BenchmarkTools.Trial:
memory estimate:  39.16 KiB
allocs estimate:  3
--------------
minimum time:     131.974 μs (0.00% GC)
median time:      152.517 μs (0.00% GC)
mean time:        166.348 μs (3.97% GC)
maximum time:     9.239 ms (98.08% GC)
--------------
samples:          10000
evals/sample:     1

``````

`median(A, dims)` dispatches to `mapslices`, which is flexible but somewhat slow. My own profiling shows that much of the runtime is spent in an auxiliary function, `concatenate_setindex!`, which may be hobbled by broadcasting overhead. JuliennedArrays delivers a substantial speedup:

``````
using JuliennedArrays

A = rand(Float32, 10, 10, 10, 10)
@btime median($A, dims=1)         # 1.251 ms (11070 allocations: 319.59 KiB)
@btime map(median, Slices($A, 1)) # 230.976 μs (2018 allocations: 192.38 KiB)

dropdims(median(A, dims=1), dims=1) == map(median, Slices(A, 1)) # true
``````
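Another option, not from the thread but a sketch of the same idea taken further: hand-roll the reduction with one reusable buffer and `median!`, which is allowed to mutate (partially sort) its argument, so the only allocations are the output and the buffer:

```julia
using Statistics

# Median along dim 1 of a 4-D float array, keeping the singleton
# dimension so the result matches `median(A, dims=1)` exactly.
function median_dim1(A::AbstractArray{T,4}) where {T<:AbstractFloat}
    w, x, y, z = size(A)
    out = Array{T}(undef, 1, x, y, z)
    buf = Vector{T}(undef, w)
    @inbounds for k in 1:z, j in 1:y, i in 1:x
        copyto!(buf, view(A, :, i, j, k))   # refill the scratch buffer
        out[1, i, j, k] = median!(buf)      # median! may reorder buf
    end
    return out
end

A = rand(Float32, 10, 10, 10, 10)
median_dim1(A) == median(A, dims=1)  # true
```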

**edit:** just for fun, a threaded `mapslices`:

``````
function tmapslices(f, A, dims...)
    sliced = Slices(A, dims...)
    return_type = typeof(f(A[1:2]))  # infer the output eltype from a sample call
    out = Array{return_type}(undef, size(sliced))
    Threads.@threads for i in eachindex(sliced)
        out[i] = f(sliced[i])
    end
    return out
end

@btime tmapslices(median, $A, 1) # 142.063 μs (238 allocations: 57.70 KiB)
``````

Wow! I was just computing the median over images and was puzzled by how slow it was. For typical image sizes, the speedup is enormous!

``````
A = rand(Float32, 1000, 1000, 10)
@btime median($A, dims=3)         # 1.884 s (14912038 allocations: 338.17 MiB)
@btime map(median, Slices($A, 3)) # 283.797 ms (2000019 allocations: 186.92 MiB)
``````

Nice! I necro’d a related issue (https://github.com/JuliaLang/julia/issues/28431).

Fantastic, thanks a lot! It would be nice if this were eventually the default code under the hood for the `dims` keyword.

If I need those singleton dimensions to remain, is there a “cleaner” way to do this than

``````
map(median, Slices(A, 1))[[CartesianIndex()],:,:,:]
``````

as the `[CartesianIndex()]` is rather ugly?

You can reshape the output; it does not cost anything.
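For example, reinserting the leading singleton dimension costs nothing, since `reshape` on an `Array` just wraps the same memory (using `dropdims` here as a stand-in for the `map(median, Slices(A, 1))` result):

```julia
using Statistics

A = rand(Float32, 10, 10, 10, 10)
m = dropdims(median(A, dims=1), dims=1)         # 3-d result without the singleton
reshape(m, 1, size(m)...) == median(A, dims=1)  # true
```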


Interesting. I didn’t realize the previous syntax was allocating a new array (that’s actually kind of annoying), but I guess it makes sense.

If I want to reshape along the first or last index, it is relatively trivial with, say,

``````
reshape(a, (1, size(a)...))
``````

However, there are several times where I need to insert a singleton dimension as a middle dimension. Is there a way to do the following in one line?

``````
w,x,y,z = size(a)
reshape(a, (w,x,y,1,z))
``````

(I realize I’m veering off topic here)

It ain’t pretty, but this works: `reshape(A, insert!([size(A)...], 4, 1)...)`
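A splatting variant of the same trick, avoiding the temporary vector and `insert!`:

```julia
a = rand(Float32, 10, 10, 10, 10)

# Splice a singleton between dims 3 and 4 by splatting the size tuple.
b = reshape(a, size(a)[1:3]..., 1, size(a)[4:end]...)
size(b)  # (10, 10, 10, 1, 10)
```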


Also relevant is https://github.com/JuliaLang/julia/pull/32310, which would allow `map(median, eachslice(A, dims=(1,2)))` instead of `median(A, dims=3)`, I think.

It looks as though `eachslice` can only accept a single dimension.

``````
ERROR: ArgumentError: only single dimensions are supported
Stacktrace:
 #eachslice#178 at ./abstractarraymath.jl:452 [inlined]
 (::Base.var"#kw##eachslice")(::NamedTuple{(:dims,),Tuple{Tuple{Int64,Int64,Int64}}}, ::typeof(eachslice), ::Array{Float32,4}) at ./none:0
 top-level scope at REPL:1
``````
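(For anyone landing here later: as of Julia 1.9, `eachslice` does accept a tuple of `dims`, returning a `Slices` object, so the multi-dimensional form works in Base:)

```julia
using Statistics

A = rand(Float32, 10, 10, 10)

# Requires Julia ≥ 1.9: iterate over dims 1 and 2; each element is a
# view along dim 3.
map(median, eachslice(A, dims=(1, 2))) == dropdims(median(A, dims=3), dims=3)  # true
```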