Optimizing a 2D masked mean

Hey everyone! I’m working on a small package for RFI mitigation in a radio telescope, and I’m getting a little stuck optimizing a small piece of the code.

Specifically, I need to be able to take the mean down the rows and columns of a matrix while ignoring positions specified by a separate boolean mask. I’m keeping the mask “positive” so that I can essentially do sum(A .* mask, dims=x) .... I would leave it at that, but that’s about 5x slower than mean(A, dims=x) from Statistics.
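
(For reference, here is a minimal sketch of the broadcast-and-sum version I’m comparing against; it assumes the mask is a Bool or 0/1 matrix the same size as A, and naive_masked_mean is just a name I’m using here:)

# Naive version: sum of unmasked entries divided by the count of unmasked
# entries along the chosen dimension. Allocates a full temporary A .* mask.
naive_masked_mean(A, mask; dims) = sum(A .* mask; dims=dims) ./ sum(mask; dims=dims)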

As such, I started to write my own masked_mean:

using LoopVectorization  # provides @turbo

function masked_mean(x::AbstractMatrix{T}, mask, axis) where {T<:Number}
    @assert size(x) == size(mask) "The mask and input must have the same size"
    @assert axis == 1 || axis == 2 "Axis must be either 1 or 2"

    # The output has one entry per column (axis == 1) or per row (axis == 2).
    if axis == 1
        ax = axes(x, 2)
    else
        ax = axes(x, 1)
    end

    len = zeros(Int64, length(ax))  # count of unmasked entries per slice
    acc = zeros(T, length(ax))      # running sum of unmasked entries per slice

    if axis == 1
        @turbo for i in axes(x, 1), j in axes(x, 2)
            acc[j] += x[i, j] * mask[i, j]
            len[j] += mask[i, j]
        end
    else
        @turbo for i in axes(x, 1), j in axes(x, 2)
            acc[i] += x[i, j] * mask[i, j]
            len[i] += mask[i, j]
        end
    end

    @turbo @. acc / len
end

This works great for axis=1 (only about 30% slower than mean), but it does not work well for axis=2: for my test case, it is 4x slower than mean.

Now, my test data is a 16384x2048 Float32 matrix, so I could believe that the axis=2 case is slower simply because the output is 8x larger. However, that is not the case for the built-in mean, which is practically the same speed in both cases. This has me stumped, and I would appreciate any insight.

Thanks!

A general principle that applies to your problem is that Julia stores 2D arrays by stacking columns first in memory, rather than rows. See
https://docs.julialang.org/en/v1/manual/performance-tips/#man-performance-column-major
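
For example, here is a quick sketch of the two traversal orders (the function names are just for illustration):

# Inner loop walks down a column: consecutive memory accesses.
function sum_colmajor(A)
    s = zero(eltype(A))
    for j in axes(A, 2), i in axes(A, 1)
        s += A[i, j]
    end
    return s
end

# Inner loop walks along a row: strided accesses, far less cache-friendly.
function sum_rowmajor(A)
    s = zero(eltype(A))
    for i in axes(A, 1), j in axes(A, 2)
        s += A[i, j]
    end
    return s
end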

Here is another version of your function:

using LinearAlgebra: dot
function new_masked_mean(X::AbstractMatrix{T}, mask; dims::Int=1) where {T<:Number}
    # eachslice(X, dims=1) iterates over rows, dims=2 over columns.
    return [dot(r, m) / sum(m) for (r, m) in zip(eachslice(X, dims=dims), eachslice(mask, dims=dims))]
end

so the following passes:

m, n = 16384, 2048

using Random
rng = MersenneTwister(42)
R = rand(rng, Float32, m, n);
M = rand(rng, 0:1, m, n);
@assert masked_mean(R, M, 1) == new_masked_mean(R, M, dims=2)
@assert masked_mean(R, M, 2) == new_masked_mean(R, M, dims=1)

Accessing columns corresponds to dims=2 in the convention used by functions in the Julia Base module. You can see this by comparing the 2nd and 3rd lines here:

A = rand(3,2)
collect(eachslice(A,dims=1))
collect(eachslice(A,dims=2))

If the boolean mask matrix has a low proportion of 1’s versus 0’s, then you might consider working directly with SparseArrays and sparsity patterns.
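
For instance, a rough sketch of that direction (purely illustrative; whether it pays off depends on how sparse the mask actually is):

using SparseArrays

# Keep only the unmasked entries; sums over the sparse matrix then only
# touch the stored (unmasked) values.
S = sparse(M .* R)
col_sums   = vec(sum(S, dims=1))
col_counts = vec(sum(M, dims=1))
col_means  = col_sums ./ col_counts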

Thanks for your reply!

A general principle that applies to your problem is that Julia stores 2D arrays by stacking columns first in memory, rather than rows

Indeed, my question is more: how does Base’s mean get around this? Your version of the masked mean is quite a bit slower and showcases why I did not use dot. Part of this is because I want to traverse both the mask and the source only once, which we’re not doing here because of the separate pass over the mask (the sum(m) call).

Accessing columns corresponds to dims=2 in the convention used by functions in the julia Base module.

This does not seem to be the case when comparing against Statistics.mean, where dims=1 indicates that we are reducing along the first axis, resulting in a singleton dimension on that axis.

julia> size(R)
(16384, 2048)
julia> size(mean(R,dims=1))
(1, 2048)

Summary

using Random

rng = MersenneTwister(42)
m, n = 16384, 2048

R = rand(rng, Float32, m, n)
mask = rand(rng, Bool, m, n)

Statistics.mean

julia> @btime mean($R,dims=1)
  8.975 ms (7 allocations: 8.66 KiB)
julia> @btime mean($R,dims=2)
  8.288 ms (12 allocations: 64.66 KiB)

Kiran’s masked_mean

julia> @btime masked_mean($R,$mask,1)
  10.085 ms (3 allocations: 32.38 KiB)
julia> @btime masked_mean($R,$mask,2)
  29.461 ms (6 allocations: 256.14 KiB)

James’ new_masked_mean

julia> @btime new_masked_mean($R,$mask;dims=2)
  50.825 ms (10 allocations: 8.39 KiB)
julia> @btime new_masked_mean($R,$mask;dims=1)
  961.092 ms (11 allocations: 64.31 KiB)

I’m on my phone right now so I can’t contribute much, but have you looked at the mean code to see what it does?


You’re right; I shouldn’t have used eachslice as the point of comparison, since it explicitly slices along the dimension complementary to dims=n.


FWIW, this version below roughly matches Statistics.mean on dims=1 and is within a factor of 2x for dims=2:

function masked_mean2(x::AbstractMatrix{T}, mask, axis) where {T<:Number}
    @assert size(x) == size(mask) "The mask and input must have the same size"
    @assert axis == 1 || axis == 2 "Axis must be either 1 or 2"

    if axis == 1
        ax = size(x, 2)
    else
        ax = size(x, 1)
    end

    len = zeros(Int64, ax)
    acc = zeros(T, ax)

    if axis == 1
        @turbo for j in axes(x, 2)
            for i in axes(x, 1)
                acc[j] += x[i, j] * mask[i, j]
                len[j] += mask[i, j]
            end
        end
    else
        # Keep the inner loop over the first index so memory access stays contiguous.
        for j in axes(x, 2)
            @turbo for i in axes(x, 1)
                acc[i] += x[i, j] * mask[i, j]
                len[i] += mask[i, j]
            end
        end
    end

    @turbo @. acc / len
end

using BenchmarkTools, Random

rng = MersenneTwister(42)
m, n = 16384, 2048
R = rand(rng, Float32, m, n);
mask = rand(rng, Bool, m, n);

@benchmark masked_mean($R,$mask,1)
@benchmark masked_mean2($R,$mask,1)

@benchmark masked_mean($R,$mask,2)
@benchmark masked_mean2($R,$mask,2)

Your original question still remains intriguing to me: how does Base’s mean get around the difference…


The code is here:

Just FYI: the final @turbo @. acc / len line allocates an entirely new output vector, instead of updating acc in place. You can do

@turbo acc .= acc ./ len

for an in-place version. Not that it makes a big difference, though.

(For some reason, @turbo did not accept acc ./= len.)


Nice catch!

LoopVectorization should reorder the loops to be maximally efficient, but yeah, I definitely had them backward the first time.

The other interesting thing is that Base is, of course, not using LoopVectorization, so how is it so quick? The code in Base doesn’t seem to be doing anything particularly clever.

Yeah, so it seems sum is really the fast part here. Lots to dig into…
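
To convince myself of what the fast pattern might look like, here is a sketch of a dims=2-style row sum that still walks the matrix in memory order (this is just the pattern, not the actual Base code):

function rowsum_memorder(A)
    out = zeros(eltype(A), size(A, 1))
    for j in axes(A, 2)                # outer loop over columns
        @inbounds for i in axes(A, 1)  # inner loop is contiguous in memory
            out[i] += A[i, j]
        end
    end
    return out
end

It accumulates one output entry per row, yet never strides across rows, which would explain why the dims=2 case doesn’t pay the column-major penalty.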

If I just use sum, I can keep a temp matrix around for the masked calculations (since I’m going to be doing this and variance a lot), so I can write:

function mean_ax_3!(tmp, R, mask, dim)
    @. tmp = R * mask       # write the masked data into the preallocated buffer
    mean(tmp, dims=dim)
end

This has the nice added benefit that it “just works” on the GPU with CUDA.jl (about a factor of 10 faster there).
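
Roughly like this (a sketch assuming CUDA.jl is set up and a GPU is available; the _d names are just mine for the device arrays):

using CUDA, Statistics

R_d    = CuArray(R)        # copy data to the GPU
mask_d = CuArray(mask)
tmp_d  = similar(R_d)      # preallocated device buffer

mean_ax_3!(tmp_d, R_d, mask_d, 1)   # broadcast and mean both run as GPU kernels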
