How to speed up `Statistics.median` along a given dimension?

I am in the process of translating a Python function into Julia. I was expecting the latter to be faster than the former, but to my surprise the Python function outperformed my Julia version each time. Using the Profile module, I discovered that my Julia function was spending >30% of its time computing a median value along the first dimension in a 4-D array (passed into the function as one of the arguments).

Benchmarking this bottleneck yielded some surprising (to me) results. Apparently the stock Statistics.median function with the dims keyword is not the way to go here. How can I speed up this computation to match Python's performance, or beat it?

Annotated Example:

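First I generate a random 4D array and compute its median along the same axis in Python and in Julia, then check that the values are indeed identical. Reversing the axes with permutedims makes NumPy's axis 3 correspond to Julia's dimension 1 (in both cases the fastest-varying axis, since NumPy arrays are row-major by default and Julia arrays are column-major).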

julia> using PyCall, Statistics, BenchmarkTools

julia> py"""
       import numpy as np
       a = np.random.rand(10,10,10,10).astype(np.float32)
       median_py = np.median(a, axis=3)
       """

julia> a = permutedims(py"a", 4:-1:1);

julia> median_py = permutedims(py"median_py", 3:-1:1)[[CartesianIndex()],:,:,:];

julia> median_jl = median(a, dims=1);

julia> median_py == median_jl
true

Next, I run those computations again with the @benchmark macro. The Julia version takes more than 4x as long to compute.

julia> @benchmark py"np.median(a, axis=3)"
BenchmarkTools.Trial: 
  memory estimate:  5.64 KiB
  allocs estimate:  33
  --------------
  minimum time:     310.933 μs (0.00% GC)
  median time:      436.488 μs (0.00% GC)
  mean time:        489.510 μs (0.00% GC)
  maximum time:     23.749 ms (0.00% GC)
  --------------
  samples:          9896
  evals/sample:     1

julia> @benchmark median(a, dims=1)
BenchmarkTools.Trial: 
  memory estimate:  319.61 KiB
  allocs estimate:  11071
  --------------
  minimum time:     1.352 ms (0.00% GC)
  median time:      1.418 ms (0.00% GC)
  mean time:        1.679 ms (1.61% GC)
  maximum time:     8.195 ms (67.38% GC)
  --------------
  samples:          2963
  evals/sample:     1

Next, I recreate a “dumb” version of the dims argument with a comprehension and check that it yields the same result. This is almost twice as fast as the built-in dims method, but still significantly slower than Python.

julia> median_jl_dumb = begin
           w,x,y,z = size(a)
           [median(a[:,i,j,k]) for i in 1:x, j in 1:y, k in 1:z][[CartesianIndex()],:,:,:]
       end;

julia> median_jl == median_jl_dumb
true

julia> @benchmark begin
           w,x,y,z = size(a)
           [median(a[:,i,j,k]) for i in 1:x, j in 1:y, k in 1:z][[CartesianIndex()],:,:,:]
       end
BenchmarkTools.Trial: 
  memory estimate:  306.38 KiB
  allocs estimate:  4047
  --------------
  minimum time:     722.054 μs (0.00% GC)
  median time:      728.818 μs (0.00% GC)
  mean time:        780.170 μs (1.53% GC)
  maximum time:     3.396 ms (68.06% GC)
  --------------
  samples:          6377
  evals/sample:     1
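The comprehension above allocates a fresh array for every a[:,i,j,k] slice. A minimal sketch of a variant that avoids this, assuming a real-valued element type: copy each dim-1 column into one reusable buffer and call Statistics.median!, which partially sorts its argument in place. The name median_dim1 is hypothetical (not part of Statistics), and I have not benchmarked it against the timings above.

```julia
using Statistics

# Hypothetical helper (not part of Statistics): median along dimension 1,
# returning an array of size (1, size(a)[2:end]...) just like median(a, dims=1).
function median_dim1(a::AbstractArray{T}) where {T<:Real}
    n = size(a, 1)
    rest = size(a)[2:end]                      # sizes of the remaining dimensions
    out = Array{float(T)}(undef, (1, rest...))
    buf = Vector{T}(undef, n)                  # one scratch buffer, reused for every column
    for I in CartesianIndices(rest)
        for k in 1:n
            buf[k] = a[k, I]                   # copy the dim-1 fiber at position I
        end
        out[1, I] = median!(buf)               # partially sorts buf in place
    end
    return out
end

a = rand(Float32, 10, 10, 10, 10)
median_dim1(a) == median(a, dims=1)   # true
```

Because median! reorders the buffer, the input array itself is never modified.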

Finally, I check that there is nothing “wrong” with Julia's Statistics.median in general by comparing it to NumPy without any axis argument, and indeed it is comparable or faster in this case:

julia> @benchmark py"np.median(a)"
BenchmarkTools.Trial: 
  memory estimate:  176 bytes
  allocs estimate:  5
  --------------
  minimum time:     181.190 μs (0.00% GC)
  median time:      186.564 μs (0.00% GC)
  mean time:        199.499 μs (0.00% GC)
  maximum time:     1.331 ms (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark median(a)
BenchmarkTools.Trial: 
  memory estimate:  39.16 KiB
  allocs estimate:  3
  --------------
  minimum time:     131.974 μs (0.00% GC)
  median time:      152.517 μs (0.00% GC)
  mean time:        166.348 μs (3.97% GC)
  maximum time:     9.239 ms (98.08% GC)
  --------------
  samples:          10000
  evals/sample:     1

median(A, dims) dispatches to mapslices, which is flexible but somewhat slow. My own profiling shows that much of the runtime is spent in an auxiliary function, concatenate_setindex!, which may be hobbled by broadcasting overhead. JuliennedArrays delivers a substantial speedup:

using JuliennedArrays

A = rand(Float32, 10, 10, 10, 10)
@btime median($A, dims=1)         # 1.251 ms (11070 allocations: 319.59 KiB)
@btime map(median, Slices($A, 1)) # 230.976 μs (2018 allocations: 192.38 KiB)

dropdims(median(A, dims=1), dims=1) == map(median, Slices(A, 1)) # true
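If you would rather avoid the extra dependency, a similar idea can be sketched with Base alone, assuming A is an ordinary Array: reshape(A, n, :) exposes the dim-1 fibers as the columns of a matrix without copying, and eachcol iterates over those columns as views. The name median_dim1_eachcol is hypothetical, and I have not timed it against JuliennedArrays.

```julia
using Statistics

# Hypothetical helper: median along dimension 1 using only Base and Statistics.
function median_dim1_eachcol(A::AbstractArray)
    n = size(A, 1)
    M = reshape(A, n, :)               # for an Array this shares memory; each column is one dim-1 fiber
    meds = map(median, eachcol(M))     # eachcol yields views; median copies each one before sorting
    return reshape(meds, (1, size(A)[2:end]...))   # restore the leading singleton dimension
end

A = rand(Float32, 10, 10, 10, 10)
median_dim1_eachcol(A) == median(A, dims=1)   # true
```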

Edit: just for fun, a threaded mapslices:

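# Threaded mapslices (needs JuliennedArrays; start Julia with several threads, e.g. via JULIA_NUM_THREADS)
function tmapslices(f, A, dims...)
    sliced = Slices(A, dims...)
    # infer the output element type by applying f to an actual slice
    return_type = typeof(f(first(sliced)))
    out = Array{return_type}(undef, size(sliced))

    # each iteration writes a distinct element of out, so no locking is needed
    Threads.@threads for i in eachindex(sliced)
        out[i] = f(sliced[i])
    end

    return out
end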

@btime tmapslices(median, $A, 1) # 142.063 μs (238 allocations: 57.70 KiB)

Wow! I was just taking medians over images and was puzzled by how slow it was. For typical image sizes, the speedup is enormous!

A = rand(Float32, 1000, 1000, 10)
@btime median($A, dims=3)         # 1.884 s (14912038 allocations: 338.17 MiB)
@btime map(median, Slices($A, 3)) # 283.797 ms (2000019 allocations: 186.92 MiB)

Nice! I necro’d a related issue (https://github.com/JuliaLang/julia/issues/28431).

Fantastic, thanks a lot! It would be nice if this were eventually the default code under the hood for the dims keyword.

If I need those singleton dimensions remaining, is there a “cleaner” way to do this than

map(median, Slices(A, 1))[[CartesianIndex()],:,:,:]

As the [CartesianIndex()] is rather ugly ;)

You can reshape the output instead. reshape shares the underlying data, so it costs essentially nothing.
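A small illustration of the difference, where m stands in for the 3-D result of map(median, Slices(A, 1)) (computed here with dropdims so the snippet needs only Statistics): reshape returns an array that shares m's memory, while indexing with [CartesianIndex()] builds a new array.

```julia
using Statistics

A = rand(Float32, 10, 10, 10, 10)
m = dropdims(median(A, dims=1), dims=1)   # 10×10×10, stand-in for map(median, Slices(A, 1))

r = reshape(m, 1, size(m)...)             # 1×10×10×10, shares memory with m (no copy)
c = m[[CartesianIndex()], :, :, :]        # 1×10×10×10, freshly allocated copy

r == c                                    # true: same shape and values
r[1, 1, 1, 1] = -1f0
m[1, 1, 1]                                # -1.0f0: writing through r changed m, but not c
```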


Interesting. I didn’t realize the previous syntax was allocating a new array (that’s actually kind of annoying), but I guess it makes sense.

If I want to add the singleton dimension first or last, it is relatively trivial with, say,

reshape(a, (1, size(a)...))

However, there are several times where I need to insert a singleton dimension as a middle dimension. Is there a way to make the following in one line?

w,x,y,z = size(a)
reshape(a, (w,x,y,1,z))

(I realize I’m veering off topic here)

It ain’t pretty, but this works: reshape(A, insert!([size(A)...], 4, 1)...)
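The same idea can be wrapped in a small helper that splices a 1 into the size tuple at any position (the name insert_singleton is made up for this sketch). Since reshape on an Array shares memory, no data is copied.

```julia
# Hypothetical helper: insert a length-1 dimension so that it becomes dimension k.
insert_singleton(A::AbstractArray, k::Integer) =
    reshape(A, (size(A)[1:k-1]..., 1, size(A)[k:end]...))

a = rand(2, 3, 4, 5)
size(insert_singleton(a, 4))   # (2, 3, 4, 1, 5)
size(insert_singleton(a, 1))   # (1, 2, 3, 4, 5)
size(insert_singleton(a, 5))   # (2, 3, 4, 5, 1)
```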


Also relevant is https://github.com/JuliaLang/julia/pull/32310, which I think would allow map(median, eachslice(A, dims=(1,2))) instead of median(A, dims=3).

It looks as though eachslice can only accept a single dimension.

ERROR: ArgumentError: only single dimensions are supported
Stacktrace:
 [1] #eachslice#178 at ./abstractarraymath.jl:452 [inlined]
 [2] (::Base.var"#kw##eachslice")(::NamedTuple{(:dims,),Tuple{Tuple{Int64,Int64,Int64}}}, ::typeof(eachslice), ::Array{Float32,4}) at ./none:0
 [3] top-level scope at REPL[15]:1
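For reference, newer Julia releases (1.9 and later, as far as I know; please check against your own version) do accept several dimensions in eachslice, returning a lazy collection of views. A sketch of the image-stack case under that assumption, guarded so that it is skipped on older versions:

```julia
using Statistics

A = rand(Float32, 100, 100, 10)   # e.g. a stack of 10 images of 100×100 pixels

if VERSION >= v"1.9"
    # dims=(1, 2) iterates over pixels (i, j); each element is the view A[i, j, :]
    meds = map(median, eachslice(A; dims=(1, 2)))   # 100×100 Matrix{Float32}
    meds == dropdims(median(A; dims=3); dims=3)     # true
end
```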