# Efficient and CuArray compatible implementation of `StatsBase._compute_extrema`

For this PR to StatsBase.jl, I am looking for a way to compute extrema that is efficient both for regular arrays and for `CuArray`s.

This implementation is efficient for regular arrays:

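```julia
function _compute_extrema_A(X; dims)
    # Matrices only: iterate over the slices along the *other* dimension
    otherdims = dims == 1 ? 2 : 1
    l = size(X, otherdims)
    # Pre-allocate both output vectors
    tmin, tmax = (similar(X, l) for _ in 1:2)
    for (i, x) in enumerate(eachslice(X; dims=otherdims))
        # `extrema` computes min and max in a single pass over the slice
        tmin[i], tmax[i] = extrema(x)
    end
    return tmin, tmax
end
```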

Notably, it minimizes allocations by assigning the output of `extrema` directly to the pre-allocated `tmin` and `tmax`. `StatsBase` needs the results in this form, as two separate vectors, for later use.

Is there an efficient implementation for `CuArray`s, i.e., one that does not use scalar `getindex`?

The following works, but it is not optimal because it computes the minimum and the maximum in two separate reductions, i.e., two passes over the data:

```julia
using CUDA

function _compute_extrema_cu(X::CuArray; dims)
    # Two independent GPU reductions, i.e. two passes over X
    tmin = vec(minimum(X; dims=dims))
    tmax = vec(maximum(X; dims=dims))
    return tmin, tmax
end
```

Preferably, the solution would use functions from Base, so that `StatsBase` would not have to take on a dependency on `CUDA`, or vice versa.

Code to test the above functions:

```julia
using CUDA
CUDA.allowscalar(false)  # throw an error on any scalar indexing of a CuArray

X = rand(Float32, 5, 8)
dims = 1

tmin, tmax = _compute_extrema_A(X; dims=dims)

tmin_cu, tmax_cu = _compute_extrema_cu(cu(X); dims=dims)
@show collect(tmin_cu) == tmin && collect(tmax_cu) == tmax
```

```julia
julia> mapreduce(x->(x,x), (a,b)->(min(a[1],b[1]),max(a[2],b[2])), arr; init=(typemax(eltype(arr)), typemin(eltype(arr))))
(0.0015687909f0, 0.9998264f0)

julia> minimum(arr)
0.0015687909f0

julia> maximum(arr)
0.9998264f0
```

If that works and is sufficient to implement `Base.extrema`, maybe open a PR to GPUArrays to add it? Then you could simply call `extrema` (with a `dims` argument).

The problem is that when passing `dims` we want to get an array with the minimum for each slice, and an array with the maximum for each slice, rather than an array of `(minimum, maximum)` tuples.
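For reference, a minimal Base-only sketch: reduce `(min, max)` pairs in a single `mapreduce` pass, then split the resulting tuple array into two plain arrays by broadcasting `first` and `last` over it. The name `_compute_extrema_split` is hypothetical, and the sketch assumes a real element type for which `typemax`/`typemin` exist. It pays for one intermediate array of tuples plus two copies. Because it relies only on `mapreduce` and broadcasting, it should in principle also avoid scalar indexing on a `CuArray`, but that is an assumption and is not tested here.

```julia
# Hypothetical helper (not part of StatsBase): one reduction pass,
# then split the (min, max) tuples into two separate vectors.
function _compute_extrema_split(X::AbstractArray{T}; dims) where {T<:Real}
    # Each element becomes a (min, max) pair; the init pair is neutral.
    mm = mapreduce(x -> (x, x),
                   (a, b) -> (min(a[1], b[1]), max(a[2], b[2])),
                   X; dims=dims, init=(typemax(T), typemin(T)))
    # Broadcasting `first`/`last` copies the tuple fields into dense arrays.
    tmin = vec(first.(mm))
    tmax = vec(last.(mm))
    return tmin, tmax
end

X = [3.0 1.0; -2.0 5.0; 0.5 4.0]
tmin, tmax = _compute_extrema_split(X; dims=1)
```

Unlike a `reinterpret`-based approach, this works for any element type, at the cost of the extra copies.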

Can’t you use `reinterpret` and `view`? Or, if you want dense arrays, do a strided memcpy?

For example:

```julia
julia> arr = CUDA.rand(3,4,5);

julia> minmax = mapreduce(x->(x,x), (a,b)->(min(a[1],b[1]),max(a[2],b[2])), arr; init=(typemax(eltype(arr)), typemin(eltype(arr))), dims=1)
1×4×5 CuArray{Tuple{Float32, Float32}, 3}:
[:, :, 1] =
(0.407247, 0.906906)  (0.225997, 0.790181)  (0.362782, 0.7733)  (0.368122, 0.550295)

[:, :, 2] =
(0.179379, 0.43576)  (0.28129, 0.769908)  (0.456139, 0.831805)  (0.264724, 0.976437)

[:, :, 3] =
(0.0381522, 0.984508)  (0.098552, 0.891221)  (0.494028, 0.996213)  (0.0941686, 0.899043)

[:, :, 4] =
(0.57849, 0.961374)  (0.0394041, 0.178669)  (0.0809094, 0.877981)  (0.0169927, 0.875119)

[:, :, 5] =
(0.0675674, 0.440386)  (0.0870846, 0.791125)  (0.700442, 0.856237)  (0.395846, 0.729174)

julia> minmax = reinterpret(Float32, minmax);

julia> minima = similar(minmax, size(minmax)[2:end]);

julia> Mem.unsafe_copy2d!(pointer(minima), Mem.Device, pointer(minmax), Mem.Device, 1, length(minima); srcPitch=2*sizeof(eltype(minmax)), dstPitch=0)

julia> reshape(minima, (1,4,5))
1×4×5 CuArray{Float32, 3}:
[:, :, 1] =
0.407247  0.225997  0.362782  0.368122

[:, :, 2] =
0.179379  0.28129  0.456139  0.264724

[:, :, 3] =
0.0381522  0.098552  0.494028  0.0941686

[:, :, 4] =
0.57849  0.0394041  0.0809094  0.0169927

[:, :, 5] =
0.0675674  0.0870846  0.700442  0.395846
``````
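The `reinterpret` + `view` route can be sketched without the device-specific memcpy. This sketch assumes Julia 1.6 or later (for `reinterpret(reshape, T, A)`), a matrix input, and an isbits element type such as `Float32` or `Float64`; `_compute_extrema_reinterpret` is a hypothetical name. The returned vectors are views into the tuple array rather than dense arrays.

```julia
# Hypothetical helper: single reduction pass, then expose minima and maxima
# as views. Requires an isbits element type (e.g. Float32, Float64).
function _compute_extrema_reinterpret(X::AbstractMatrix{T}; dims) where {T<:Real}
    mm = mapreduce(x -> (x, x),
                   (a, b) -> (min(a[1], b[1]), max(a[2], b[2])),
                   X; dims=dims, init=(typemax(T), typemin(T)))
    # View the Tuple{T,T} matrix as a T array with a new leading axis of size 2:
    # index 1 along that axis holds the minima, index 2 the maxima.
    r = reinterpret(reshape, T, mm)
    tmin = vec(view(r, 1, :, :))  # no copy
    tmax = vec(view(r, 2, :, :))  # no copy
    return tmin, tmax
end

X = [3.0 1.0; -2.0 5.0; 0.5 4.0]
tmin, tmax = _compute_extrema_reinterpret(X; dims=1)
```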

Ah, yes, `reinterpret` with a view would work. However, that only supports isbits element types, so we would need two separate code paths. Unfortunately, that would make the handling of all arrays (including `Array`) more complex, for a benefit that would only affect `CuArray`. Life is hard…

Thanks for your thoughts and ideas. To keep things simple, I propose that we just merge the current PR; it has many benefits even without resolving this problem.

Once `extrema` is implemented for CUDA, we can perhaps write a near-optimal version of `_compute_extrema` for `CuArray`. A user who needs the last bit of performance will have to write a separate type that organizes `tmin` and `tmax` differently, which is not too hard to do.