Hey everyone! I’m working on a small package for radio-frequency interference (RFI) mitigation in a radio telescope, and I’m getting a little stuck optimizing one small piece of the code.
Specifically, I need to be able to take the mean down the rows and columns of a matrix while ignoring positions specified by a separate boolean mask. I’m keeping the mask “positive” (`true` marks a sample to keep), so I can essentially do `sum(A .* mask, dims=x) ...`. I would leave it at that, but that’s about 5x slower than `mean(A, dims=x)` from Statistics.
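For reference, here is a minimal sketch of that broadcasting baseline written out in full. The denominator `sum(mask; dims=dims)` is an assumption about what the elided part of the expression computes, and `masked_mean_bcast` is a hypothetical name:

```julia
using Statistics  # for `mean`, used only in the comparison at the bottom

# Hypothetical broadcasting baseline: masked mean along `dims`, assuming
# `true` in `mask` marks a sample to keep.
function masked_mean_bcast(A::AbstractMatrix, mask::AbstractMatrix{Bool}; dims::Integer)
    num = sum(A .* mask; dims=dims)  # A .* mask allocates a full-size temporary
    den = sum(mask; dims=dims)       # number of kept samples per slice
    return num ./ den                # NaN where every sample in a slice is masked
end

# With an all-true mask this should agree with `mean`.
A = rand(Float32, 8, 4)
masked_mean_bcast(A, trues(8, 4); dims=2) ≈ mean(A; dims=2)
```

Note that `A .* mask` materializes a temporary the size of the input (128 MiB for a 16384x2048 Float32 matrix) before any summing happens, which is plausibly a large share of the gap to `mean`.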
As such, I started to write my own `masked_mean`:
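```julia
using LoopVectorization  # provides @turbo

# Mean of `x` along `axis`, counting only entries where `mask` is true.
function masked_mean(x::AbstractMatrix{T}, mask, axis) where {T<:Number}
    @assert size(x) == size(mask) "The mask and input must have the same size"
    @assert axis == 1 || axis == 2 "Axis must be either 1 or 2"
    # The output has one entry per slice along the dimension NOT reduced over.
    if axis == 1
        ax = axes(x, 2)
    else
        ax = axes(x, 1)
    end
    len = zeros(Int64, length(ax))  # number of unmasked samples per output entry
    acc = zeros(T, length(ax))      # masked sum per output entry
    if axis == 1
        # Reduce down each column: one result per column j.
        @turbo for i in axes(x, 1), j in axes(x, 2)
            acc[j] += x[i, j] * mask[i, j]
            len[j] += mask[i, j]
        end
    else
        # Reduce across each row: one result per row i.
        @turbo for i in axes(x, 1), j in axes(x, 2)
            acc[i] += x[i, j] * mask[i, j]
            len[i] += mask[i, j]
        end
    end
    @turbo @. acc / len
end
```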
This works great for `axis=1`, only about 30% slower than `mean`, but it does not work well for `axis=2`: for my test case, it is 4x slower than `mean`.
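To check whether the `axis=2` slowdown comes from `@turbo` itself or from the memory access pattern, one option is a dependency-free sketch using only Base `@inbounds @simd`. It assumes 1-based arrays, a floating-point input, and a `Bool` mask; it keeps the counter in the same element type as the data, and orders both branches so the inner loop runs down a column (Julia arrays are column-major). `masked_mean_base` is a hypothetical name, and this version has not been benchmarked:

```julia
# Hypothetical Base-only variant of `masked_mean` (no LoopVectorization).
# Assumes 1-based arrays, a floating-point input, and a Bool mask where
# `true` marks a sample to keep.
function masked_mean_base(x::AbstractMatrix{T}, mask::AbstractMatrix{Bool},
                          axis::Integer) where {T<:AbstractFloat}
    size(x) == size(mask) ||
        throw(DimensionMismatch("mask and input must have the same size"))
    axis in (1, 2) || throw(ArgumentError("axis must be 1 or 2"))
    nr, nc = size(x)
    if axis == 1
        # One result per column: reduce each column into scalar accumulators.
        out = Vector{T}(undef, nc)
        for j in 1:nc
            acc = zero(T)
            cnt = zero(T)  # counter has the same width as the data
            @inbounds @simd for i in 1:nr
                m = mask[i, j]
                acc += ifelse(m, x[i, j], zero(T))
                cnt += m
            end
            out[j] = acc / cnt
        end
        return out
    else
        # One result per row: sweep the columns, updating every row's accumulator.
        acc = zeros(T, nr)
        cnt = zeros(T, nr)
        for j in 1:nc
            @inbounds @simd for i in 1:nr
                m = mask[i, j]
                acc[i] += ifelse(m, x[i, j], zero(T))
                cnt[i] += m
            end
        end
        return acc ./ cnt
    end
end
```

A `Matrix{Bool}` mask may vectorize better than a `BitMatrix` in these inner loops, since reading a single bit needs extra shifting and masking.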
Now, my test data is a 16384x2048 Float32 matrix, so I would have expected the `axis=2` case to be slower: its output is 8x larger (16384 entries instead of 2048). However, this is not the case for the built-in `mean`; it’s practically the same speed as in the `axis=1` case. This kinda has me stumped, and I would appreciate any insight.
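One possible explanation, offered only as a guess: if the `axis=2` loop runs column by column, each column update reads and writes the whole `acc` and `len` vectors, which for 16384 rows take 16384 × (4 + 8) bytes = 192 KiB with a Float32 accumulator and an Int64 counter, likely more than the L1 cache holds. `mean(A, dims=2)` only needs one accumulator per row, because the count is the same for every row. A sketch that tests this idea processes the rows in blocks, so each block's accumulators stay in cache while all columns are swept. `masked_mean_rows_blocked` and the default `blk=2048` are hypothetical and untuned:

```julia
# Hypothetical row-blocked masked mean along dims=2 (one result per row).
# Assumes 1-based arrays and a Bool mask where `true` marks a sample to keep.
# `blk` is an assumed tuning knob: with Float32 data and blk=2048 the block's
# accumulators take 2 * 2048 * 4 bytes = 16 KiB, meant to fit in L1 cache.
function masked_mean_rows_blocked(x::AbstractMatrix{T}, mask::AbstractMatrix{Bool};
                                  blk::Int=2048) where {T<:AbstractFloat}
    size(x) == size(mask) ||
        throw(DimensionMismatch("mask and input must have the same size"))
    nr, nc = size(x)
    acc = zeros(T, nr)
    cnt = zeros(T, nr)  # same width as the data, unlike an Int64 counter
    for r0 in 1:blk:nr
        rows = r0:min(r0 + blk - 1, nr)
        # Sweep every column for this block of rows only.
        for j in 1:nc
            @inbounds @simd for i in rows
                m = mask[i, j]
                acc[i] += ifelse(m, x[i, j], zero(T))
                cnt[i] += m
            end
        end
    end
    return acc ./ cnt  # NaN for rows with no unmasked samples
end
```

The trade-off is that each block re-walks every column with a stride of `nr` elements between columns, so whether this beats the unblocked loop would have to be measured on the actual 16384x2048 data.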
Thanks!