Efficient and CuArray compatible implementation of `StatsBase._compute_extrema`

For this PR to StatsBase.jl, I am looking for a way to compute extrema that is efficient for both regular arrays and CuArrays.

This implementation is efficient for regular arrays:

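```julia
# Per-slice extrema of a matrix X, reduced along `dims` (1 or 2).
function _compute_extrema_A(X; dims)
    otherdims = dims == 1 ? 2 : 1      # the dimension that is kept
    l = size(X, otherdims)
    tmin, tmax = (similar(X, l) for _ in 1:2)
    for (i, x) in enumerate(eachslice(X; dims=otherdims))
        tmin[i], tmax[i] = extrema(x)  # one pass per slice, written in place
    end
    return tmin, tmax
end
```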

Notably, it minimizes allocations by assigning the output of `extrema` directly to the pre-allocated `tmin` and `tmax`. StatsBase needs `tmin` and `tmax` in exactly this form (two separate vectors) later on.
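Taking the preallocation idea one step further, the caller could own the output buffers and reuse them across calls. The sketch below is a hypothetical mutating variant, `_compute_extrema_A!` (the name is an assumption, not a StatsBase function), restricted like the original to matrices with `dims` equal to 1 or 2:

```julia
# Hypothetical in-place variant (not part of StatsBase): the caller supplies
# tmin and tmax, so repeated calls can reuse the same buffers.
function _compute_extrema_A!(tmin::AbstractVector, tmax::AbstractVector,
                             X::AbstractMatrix; dims::Integer)
    otherdims = dims == 1 ? 2 : 1
    # One slice per output entry; each slice is reduced along `dims`.
    for (i, x) in enumerate(eachslice(X; dims=otherdims))
        tmin[i], tmax[i] = extrema(x)  # single pass over the slice
    end
    return tmin, tmax
end

X = rand(Float32, 5, 8)
tmin = similar(X, size(X, 2))
tmax = similar(X, size(X, 2))
_compute_extrema_A!(tmin, tmax, X; dims=1)
```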

Is there an efficient implementation for CuArrays, i.e., one that avoids scalar `getindex`?

The following works, but is not optimal because it computes the minimum and maximum separately, i.e., in two passes over the data:

```julia
using CUDA

# Avoids scalar indexing, but reads X twice (once per reduction).
function _compute_extrema_cu(X::CuArray; dims)
    tmin = vec(minimum(X; dims=dims))
    tmax = vec(maximum(X; dims=dims))
    return tmin, tmax
end
```

Preferably, the solution would use only functions from Base, so that StatsBase would not have to take on a dependency on CUDA, or vice versa.

Code to test the above functions:

```julia
using CUDA
CUDA.allowscalar(false)  # error on any scalar indexing of a CuArray

X = rand(Float32, 5, 8)
dims = 1

tmin, tmax = _compute_extrema_A(X; dims=dims)

tmin_cu, tmax_cu = _compute_extrema_cu(cu(X); dims=dims)
@show collect(tmin_cu) == tmin && collect(tmax_cu) == tmax
```

What about:

```julia
julia> mapreduce(x->(x,x), (a,b)->(min(a[1],b[1]),max(a[2],b[2])), arr; init=(typemax(eltype(arr)), typemin(eltype(arr))))
(0.0015687909f0, 0.9998264f0)

julia> minimum(arr)
0.0015687909f0

julia> maximum(arr)
0.9998264f0
```

If that works and is sufficient to implement `Base.extrema`, maybe open a PR to GPUArrays to add it? That way you could just call `extrema` (with a `dims` argument).
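For readability, the one-liner above can be split into a named combiner and a small wrapper. This is a CPU sketch; `combine_minmax` and `extrema_mapreduce` are hypothetical names, and the results are compared against `Base.extrema`, which already returns `(min, max)` tuples:

```julia
# Associative combiner for (min, max) pairs; mapreduce (and GPU reductions)
# may apply it in any grouping. Named to avoid clashing with Base.minmax.
combine_minmax(a, b) = (min(a[1], b[1]), max(a[2], b[2]))

# Hypothetical helper: single-pass extrema by mapping x -> (x, x) and folding.
# (typemax(T), typemin(T)) is the identity element of combine_minmax.
function extrema_mapreduce(arr::AbstractArray{T}; dims=:) where {T<:Real}
    return mapreduce(x -> (x, x), combine_minmax, arr;
                     dims=dims, init=(typemax(T), typemin(T)))
end

arr = rand(Float32, 3, 4, 5)
lo, hi = extrema_mapreduce(arr)           # whole array: a single tuple
per_col = extrema_mapreduce(arr; dims=1)  # 1×4×5 array of (min, max) tuples
```

Because `combine_minmax` is associative and the `init` tuple is its identity, the result should not depend on how the reduction is grouped, which is what lets a parallel reduction use it. The `T<:Real` bound is there because `typemax`/`typemin` must be defined for the element type.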

The problem is that, when passing `dims`, we want one array with the minimum of each slice and another with the maximum of each slice, rather than a single array of `(minimum, maximum)` tuples.
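One Base-only way to get two separate arrays is to broadcast `first` and `last` over the tuple array. The sketch runs on a CPU `Array` using `Base.extrema`; applying the same broadcasts to the output of the `mapreduce` call above on a `CuArray` is an assumption that relies on GPU broadcast support and is not tested here. `split_minmax` is a hypothetical name:

```julia
# Split an array of (min, max) tuples into two dense arrays using only Base.
# Broadcasting is elementwise, so array types that implement broadcast
# (assumed here to include CuArray) should not need scalar indexing.
split_minmax(mm::AbstractArray{<:Tuple}) = (first.(mm), last.(mm))

X = rand(Float32, 5, 8)
mm = extrema(X; dims=1)              # 1×8 Matrix{Tuple{Float32, Float32}}
tmin, tmax = vec.(split_minmax(mm))  # two length-8 vectors, as StatsBase needs
```

This costs one allocation for the tuple array plus one elementwise pass to split it, but the input `X` is still read only once.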

Can’t you reinterpret and view? Or, if you want dense arrays, do a strided memcpy?

For example:

```julia
julia> arr = CUDA.rand(3,4,5);

julia> minmax = mapreduce(x->(x,x), (a,b)->(min(a[1],b[1]),max(a[2],b[2])), arr; init=(typemax(eltype(arr)), typemin(eltype(arr))), dims=1)
1×4×5 CuArray{Tuple{Float32, Float32}, 3}:
[:, :, 1] =
 (0.407247, 0.906906)  (0.225997, 0.790181)  (0.362782, 0.7733)  (0.368122, 0.550295)

[:, :, 2] =
 (0.179379, 0.43576)  (0.28129, 0.769908)  (0.456139, 0.831805)  (0.264724, 0.976437)

[:, :, 3] =
 (0.0381522, 0.984508)  (0.098552, 0.891221)  (0.494028, 0.996213)  (0.0941686, 0.899043)

[:, :, 4] =
 (0.57849, 0.961374)  (0.0394041, 0.178669)  (0.0809094, 0.877981)  (0.0169927, 0.875119)

[:, :, 5] =
 (0.0675674, 0.440386)  (0.0870846, 0.791125)  (0.700442, 0.856237)  (0.395846, 0.729174)

julia> minmax = reinterpret(Float32, minmax);

julia> minima = similar(minmax, size(minmax)[2:end]);

julia> Mem.unsafe_copy2d!(pointer(minima), Mem.Device, pointer(minmax), Mem.Device, 1, length(minima); srcPitch=2*sizeof(eltype(minmax)), dstPitch=0)

julia> reshape(minima, (1,4,5))
1×4×5 CuArray{Float32, 3}:
[:, :, 1] =
 0.407247  0.225997  0.362782  0.368122

[:, :, 2] =
 0.179379  0.28129  0.456139  0.264724

[:, :, 3] =
 0.0381522  0.098552  0.494028  0.0941686

[:, :, 4] =
 0.57849  0.0394041  0.0809094  0.0169927

[:, :, 5] =
 0.0675674  0.0870846  0.700442  0.395846
```
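The same layout trick can be expressed without raw pointers on a CPU `Array`: `reinterpret(reshape, T, A)` (Julia 1.6 or later) exposes the two tuple fields as a new leading dimension, and `selectdim` picks out the minima or maxima as views. `minmax_views` is a hypothetical helper, and whether these calls stay on the GPU for a `CuArray` is not checked here:

```julia
# Requires Julia >= 1.6 for reinterpret(reshape, T, A).
# Zero-copy views of the minima and maxima stored in an array of
# (min, max) tuples. Only valid when the tuple type is isbits.
function minmax_views(mm::AbstractArray{Tuple{T,T}}) where {T}
    flat = reinterpret(reshape, T, mm)  # size (2, size(mm)...)
    return selectdim(flat, 1, 1), selectdim(flat, 1, 2)
end

arr = rand(Float32, 3, 4, 5)
mm = extrema(arr; dims=1)      # 1×4×5 array of tuples
mins, maxs = minmax_views(mm)  # views sharing memory with mm
dense_mins = copy(mins)        # materialize if a dense array is needed
```

The views alias `mm`, so writing to one changes the other; `copy` gives independent dense arrays at the cost of an allocation.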

Ah, yes, `reinterpret` with a view would work. However, that only supports isbits element types, so we would need two different code paths. Unfortunately, that would make the handling of all arrays (including `Array`) more complex, for a benefit that would only affect `CuArray`. Life is hard…

Thanks for your thoughts and ideas. To keep things simple, I propose to just merge the current PR; it has many benefits even without a resolution of this problem.

Once `extrema` is implemented for CUDA, we can perhaps write a near-optimal version of `_compute_extrema` for `CuArray`. A user who needs the last bit of performance would have to write a separate type that organizes `tmin` and `tmax` differently. This is not too hard to write.