Make this code fast (median pooling)

I have an image of size, say, (1000,1000), and I want to reduce it to size (200,200) by tiling it with (5,5) blocks and replacing each block by its median. (The purpose of the median is to remove some salt-and-pepper-like artifacts. I am fine with replacing it by something similarly robust, e.g. a median of medians, if that speeds things up.)
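As a slow but easy-to-check baseline, here is a minimal sketch of both variants using `Statistics.median`. The names `medpool` and `medmedpool` and the tile-size parameter `k` are hypothetical; 1-based indexing is assumed, and rows/columns that do not fill a whole tile are dropped, as the `÷ 5` in the question does. The small example shows that the median of row medians is only an approximation of the exact tile median.

```julia
using Statistics

# Exact median pooling: each k×k tile of `img` becomes the median of its k^2 entries.
# Rows/columns that do not fill a whole tile are ignored.
function medpool(img::AbstractMatrix, k::Integer = 5)
    m, n = size(img) .÷ k
    out = Matrix{float(eltype(img))}(undef, m, n)
    for j in 1:n, i in 1:m
        tile = view(img, k*i-k+1:k*i, k*j-k+1:k*j)
        out[i, j] = median(tile)
    end
    return out
end

# Cheaper robust variant: median of the k row medians of each tile.
function medmedpool(img::AbstractMatrix, k::Integer = 5)
    m, n = size(img) .÷ k
    out = Matrix{float(eltype(img))}(undef, m, n)
    for j in 1:n, i in 1:m
        tile = view(img, k*i-k+1:k*i, k*j-k+1:k*j)
        rowmeds = [median(view(tile, r, :)) for r in 1:k]
        out[i, j] = median(rowmeds)
    end
    return out
end

# The two can differ: 9 zeros and 16 nines give an exact median of 9,
# but three of the five row medians are 0.
M = [0 0 0 9 9; 0 0 0 9 9; 0 0 0 9 9; 9 9 9 9 9; 9 9 9 9 9]
medpool(M)     # 1×1 matrix holding 9.0
medmedpool(M)  # 1×1 matrix holding 0.0
```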

I wrote some optimized code that follows the performance tips. However, I could not bribe the compiler into using @simd. Is it possible to do this? Any other ideas on how to speed this up?

```julia
# Median of five values via a sorting network (min/max only, no branches), from
# https://github.com/JeffreySarnoff/SortingNetworks.jl/blob/master/src/swapsort.jl
@inline function median5_swap(a,b,c,d,e)
    a,b = minmax(a,b)
    c,d = minmax(c,d)
    a,c = minmax(a,c)   # a is now the minimum of a,b,c,d and cannot be the median
    b,d = minmax(b,d)   # d is now the maximum of a,b,c,d and cannot be the median
    c,e = minmax(e,c)
    max(c, min(e,b))    # median of the three remaining candidates b, c, e
end
@inline median5(args...) = median5_swap(args...)

# Reduce each 5×5 tile of `img` to the median of its five row medians
# (a robust approximation of the exact 25-element median).
function medmedpool55!(out::AbstractMatrix, img::AbstractMatrix)
    # @inbounds skips bounds checks, so every tile must lie inside `img`
    @assert 5size(out, 1) <= size(img, 1)
    @assert 5size(out, 2) <= size(img, 2)
    @inbounds for j ∈ axes(out, 2)
        @simd for i ∈ axes(out, 1)
            x11 = img[5i-4, 5j-4]
            x21 = img[5i-3, 5j-4]
            x31 = img[5i-2, 5j-4]
            x41 = img[5i-1, 5j-4]
            x51 = img[5i-0, 5j-4]

            x12 = img[5i-4, 5j-3]
            x22 = img[5i-3, 5j-3]
            x32 = img[5i-2, 5j-3]
            x42 = img[5i-1, 5j-3]
            x52 = img[5i-0, 5j-3]

            x13 = img[5i-4, 5j-2]
            x23 = img[5i-3, 5j-2]
            x33 = img[5i-2, 5j-2]
            x43 = img[5i-1, 5j-2]
            x53 = img[5i-0, 5j-2]

            x14 = img[5i-4, 5j-1]
            x24 = img[5i-3, 5j-1]
            x34 = img[5i-2, 5j-1]
            x44 = img[5i-1, 5j-1]
            x54 = img[5i-0, 5j-1]

            x15 = img[5i-4, 5j-0]
            x25 = img[5i-3, 5j-0]
            x35 = img[5i-2, 5j-0]
            x45 = img[5i-1, 5j-0]
            x55 = img[5i-0, 5j-0]

            # median of each row of the tile
            y1 = median5(x11,x12,x13,x14,x15)
            y2 = median5(x21,x22,x23,x24,x25)
            y3 = median5(x31,x32,x33,x34,x35)
            y4 = median5(x41,x42,x43,x44,x45)
            y5 = median5(x51,x52,x53,x54,x55)

            z = median5(y1,y2,y3,y4,y5)
            out[i,j] = z
        end
    end
    out
end

using BenchmarkTools
imgs = randn(Float32, 1024, 1024, 10)
img = view(imgs, :, :, 1)
out = similar(img, size(img) .÷ 5)    # 204×204; the last 4 rows/columns of img are not pooled
@benchmark medmedpool55!($out, $img)  # interpolate globals with $
```
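One restructuring that may give `@simd` a better chance (a sketch, not benchmarked here): in the loop above, consecutive `i` iterations read rows five apart, so the loads are strided, which is likely one obstacle to vectorization. Splitting each column strip into two passes lets the inner `@simd` loop walk down the columns contiguously. The names `med5`, `medmedpool55_twopass!` and the scratch vector `rowmed` are hypothetical; the result is the same median of row medians as `medmedpool55!`. Whether LLVM actually vectorizes pass 1 depends on the Julia version and element type, so check with `@code_llvm` and benchmark.

```julia
# Same sorting-network median of five as median5_swap in the question.
@inline function med5(a, b, c, d, e)
    a, b = minmax(a, b)
    c, d = minmax(c, d)
    a, c = minmax(a, c)
    b, d = minmax(b, d)
    c, e = minmax(e, c)
    max(c, min(e, b))
end

# Hypothetical two-pass variant: `rowmed` is scratch space of length >= 5*size(out, 1).
function medmedpool55_twopass!(out::AbstractMatrix, img::AbstractMatrix,
                               rowmed::AbstractVector = similar(img, 5size(out, 1)))
    m, n = size(out)
    @assert 5m <= size(img, 1) && 5n <= size(img, 2)
    @assert length(rowmed) >= 5m
    @inbounds for j in 1:n
        c = 5j - 4
        # Pass 1: for every image row k in the strip, the median across columns c:c+4.
        # Consecutive k read consecutive memory (column-major layout).
        @simd for k in 1:5m
            rowmed[k] = med5(img[k, c], img[k, c+1], img[k, c+2], img[k, c+3], img[k, c+4])
        end
        # Pass 2: median of the five row medians belonging to each tile.
        for i in 1:m
            r = 5i - 4
            out[i, j] = med5(rowmed[r], rowmed[r+1], rowmed[r+2], rowmed[r+3], rowmed[r+4])
        end
    end
    out
end
```

Usage mirrors the question, e.g. `medmedpool55_twopass!(out, img)`; passing a preallocated `rowmed` avoids the allocation on every call.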

Probably you can take the median of 5 numbers slightly more quickly:
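Whichever faster median-of-5 formulation is tried, it should still return exactly the value of a sort-based median. A minimal sketch of an exhaustive check, where `median5_sort` and `check_median5` are hypothetical helper names and `median5_swap` is copied from the question:

```julia
# Median of five via the sorting network from the question (SortingNetworks.jl swapsort)
@inline function median5_swap(a, b, c, d, e)
    a, b = minmax(a, b)
    c, d = minmax(c, d)
    a, c = minmax(a, c)
    b, d = minmax(b, d)
    c, e = minmax(e, c)
    max(c, min(e, b))
end

# Reference answer: sort the five values and take the middle one.
median5_sort(a, b, c, d, e) = sort([a, b, c, d, e])[3]

# Code built only from min/max depends only on the relative order of its inputs
# (NaN aside), and every ordering of five values, ties included, occurs among the
# 5^5 tuples from 1:5, so this loop is an exhaustive check for such candidates.
function check_median5(f)
    for t in Iterators.product(1:5, 1:5, 1:5, 1:5, 1:5)
        f(t...) == median5_sort(t...) || return false
    end
    return true
end

check_median5(median5_swap)          # true
check_median5((a, b, c, d, e) -> c)  # false: returning the middle argument is not a median
```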

Not directly answering your question, and progress on it looks stalled, but you might want to check out https://github.com/JuliaImages/ImageFiltering.jl/pull/34.

Cheers,
Kevin

Using SIMD.jl, it is indeed possible to get a nice speedup. The full code is messy, but there is a small gist here.

Did you follow any tutorials on the use of SIMD?

I’ve used the built-in @simd from Base, but not the explicit SIMD.jl package, and I would like to learn how to use it properly.

No, I am not aware of a tutorial, but the README of SIMD.jl explains basic usage quite well. I had some performance problems when passing (contiguous) views; to solve them I needed to dig into the SIMD.jl source code.