For this PR to StatsBase.jl, I am looking for a way to compute extrema that is efficient for both regular arrays and CuArrays.
This implementation is efficient for regular arrays:
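function _compute_extrema_A(X; dims)
    # Slice along the dimension that is not being reduced.
    otherdims = dims == 1 ? 2 : 1
    l = size(X, otherdims)
    # Pre-allocate one output vector each for the minima and the maxima.
    tmin, tmax = (similar(X, l) for _ in 1:2)
    for (i, x) in enumerate(eachslice(X; dims=otherdims))
        tmin[i], tmax[i] = extrema(x)
    end
    return tmin, tmax
end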
Notably, it minimizes allocations by writing the output of extrema directly into the pre-allocated tmin and tmax. StatsBase later needs the result in exactly this form: two separate vectors, tmin and tmax, with one entry per slice.
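To make the expected output shape concrete, here is a small worked example using only Base. The 2×3 matrix is made up for illustration; it just shows that with dims=1 each output vector has one entry per column.

```julia
# Illustrative 2×3 matrix (column-major): columns are [1, 2], [3, 4], [5, 6].
X = reshape(Float32.(1:6), 2, 3)

# dims=1 reduces over the rows, leaving one value per column.
tmin = vec(minimum(X; dims=1))   # one minimum per column
tmax = vec(maximum(X; dims=1))   # one maximum per column

@show tmin tmax
```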
Is there an efficient implementation for CuArrays, i.e. without using scalar getindex?
The following works, but is not optimal because it computes the minimum and maximum separately:
function _compute_extrema_cu(X::CuArray; dims)
    # Two separate reductions over X: one for the minima, one for the maxima.
    tmin = vec(minimum(X; dims=dims))
    tmax = vec(maximum(X; dims=dims))
    return tmin, tmax
end
Preferably, the solution would use functions from Base, so that StatsBase would not have to take on a dependency on CUDA, or vice versa.
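One direction I can sketch with Base alone is a single mapreduce that carries a (min, max) tuple through the reduction. The helper name _compute_extrema_mapreduce is hypothetical, and this is only a sketch: it runs on regular arrays, but I have not verified that it works on a CuArray, or that it is actually faster there than two separate reductions.

```julia
# Hypothetical helper (the name is illustrative; not part of StatsBase or CUDA.jl).
# Computes per-slice minima and maxima in a single reduction using only Base.
function _compute_extrema_mapreduce(X::AbstractMatrix{T}; dims) where {T<:Real}
    # Map each element x to the pair (x, x), then combine pairs into (min, max).
    # init is the neutral element: (largest value, smallest value) of type T.
    mm = mapreduce(x -> (x, x),
                   (a, b) -> (min(a[1], b[1]), max(a[2], b[2])),
                   X; dims=dims, init=(typemax(T), typemin(T)))
    # Split the array of tuples into two plain vectors via broadcasting,
    # which avoids scalar indexing.
    tmin = vec(first.(mm))
    tmax = vec(last.(mm))
    return tmin, tmax
end

# Usage on a regular array (CPU only; GPU behaviour is an untested assumption).
X = rand(Float32, 5, 8)
tmin, tmax = _compute_extrema_mapreduce(X; dims=1)
```

The broadcasts at the end still allocate two result vectors, so the saving, if any, would come from reading X once instead of twice.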
Code to test the above functions:
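using CUDA

# Disallow scalar indexing so any accidental scalar getindex on the GPU errors.
CUDA.allowscalar(false)
X = rand(Float32, 5, 8)
dims = 1
tmin, tmax = _compute_extrema_A(X; dims=dims)
tmin_cu, tmax_cu = _compute_extrema_cu(cu(X); dims=dims)
# Copy the GPU results back to the host and compare with the CPU results.
@show collect(tmin_cu) == tmin && collect(tmax_cu) == tmax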