For this PR to StatsBase.jl, I am looking for a way to compute extrema that is efficient for both regular arrays and CuArrays.
This implementation is efficient for regular arrays:
function _compute_extrema_A(X; dims)
    otherdims = dims==1 ? 2 : 1
    l = size(X, otherdims)
    tmin, tmax = (similar(X, l) for _ in 1:2)
    for (i, x) in enumerate(eachslice(X; dims=otherdims))
        tmin[i], tmax[i] = extrema(x)
    end
    return tmin, tmax
end
Notably, it minimizes allocations by assigning the output of extrema directly to the pre-allocated tmin and tmax. StatsBase needs tmin and tmax in this form, as two separate vectors, for later use.
Is there an efficient implementation for CuArrays, i.e. one that avoids scalar getindex?
The following works, but is not optimal because it computes the minimum and maximum separately:
function _compute_extrema_cu(X::CuArray; dims)
    tmin = vec(minimum(X; dims=dims))
    tmax = vec(maximum(X; dims=dims))
    return tmin, tmax
end
Preferably, the solution would use only functions from Base, so that StatsBase would not have to take on a dependency on CUDA, or vice versa.
Code to test the above functions:
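using CUDA

# Disallow scalar indexing so any scalar getindex on a CuArray throws
CUDA.allowscalar(false)
X = rand(Float32, 5, 8)
dims = 1
tmin, tmax = _compute_extrema_A(X; dims=dims)
tmin_cu, tmax_cu = _compute_extrema_cu(cu(X); dims=dims)
# Copy the GPU results back to the host and compare with the CPU results
@show collect(tmin_cu)==tmin && collect(tmax_cu)==tmax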
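For reference, here is a minimal sketch of a single-pass candidate that uses only Base: one mapreduce call that carries a (min, max) tuple per slice. The name _compute_extrema_fused is hypothetical and not part of StatsBase or CUDA.jl. The sketch assumes a floating-point element type so that typemax/typemin give neutral starting values. It runs on regular arrays; whether the same call runs efficiently on a CuArray, or at all, depends on the CUDA.jl version and has not been verified here.

```julia
# Hypothetical helper: computes min and max along `dims` in one reduction.
function _compute_extrema_fused(X::AbstractMatrix{T}; dims) where {T<:AbstractFloat}
    # Neutral element for the (min, max) pair: min starts at +Inf, max at -Inf
    init = (typemax(T), typemin(T))
    ext = mapreduce(x -> (x, x),                                   # lift each value to a (min, max) pair
                    (a, b) -> (min(a[1], b[1]), max(a[2], b[2])),  # combine two pairs
                    X; dims=dims, init=init)
    # Split the array of tuples into the two vectors expected downstream
    return vec(map(first, ext)), vec(map(last, ext))
end

X = rand(Float32, 5, 8)
tmin, tmax = _compute_extrema_fused(X; dims=1)
```

The explicit init is passed because a reduction over dims with a custom operator needs a starting value. Splitting the tuples with map(first, ...) and map(last, ...) avoids scalar indexing, at the cost of two extra small allocations.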