Argmax mapreduce on GPU

Hello! I am trying to quickly compute

\text{argmax}_{\substack{1 \leq s \leq k\\ k+1 \leq t \leq n}} A_{s,t}^2 + (1-\ell_s)(1 + \ell_t)

I do this on the CPU with the following code.

using Folds
using Base.Iterators: product

f = ((i, j),) -> (i, j, A[i, j]^2 + (1 - l[i])*(1 + l[j]))  # value together with its indices
op = (x, y) -> x[3] > y[3] ? x : y                          # keep the pair with the larger value
col = product(1:k, (k+1):n)
i, j, volf = Folds.mapreduce(f, op, col; init=(0, 0, -Inf))

I would like to convert this code to something that can utilize CUDA. I have done so as follows.

using CUDA
using CUDA: CUBLAS, @allowscalar

l1 = CuVector(l[1:k])      # device copies of the two halves of l
l2 = CuVector(l[k+1:n])
C = CuMatrix{Float64}(undef, k, n-k)
copyto!(C, view(A, :, k+1:n))
C .^= 2
CUBLAS.ger!(1.0, 1 .- l1, 1 .+ l2, C)  # rank-1 update: C .+= (1 .- l1) * (1 .+ l2)'

s = argmax(C)
I = CartesianIndices(C)[s]
i, j = I[1], I[2] + k      # shift the column index back into (k+1):n
@allowscalar volf = C[s]

However, I’d like to do this without materializing every element of C in VRAM, as in the CPU version. I would also prefer to avoid writing my own CUDA kernel.

I have looked into using FoldsCUDA.jl, but it seems to be deprecated and doesn’t support recent CUDA versions. It is also not maintained by JuliaFolds2.

Any suggestions?

JuliaGPU/AcceleratedKernels.jl (cross-architecture parallel algorithms for Julia’s CPU and GPU backends: multithreaded CPUs, plus GPUs via Intel oneAPI, AMD ROCm, Apple Metal, and Nvidia CUDA) might be a good place to look for code like this.

Good suggestion, but see here. In particular, things like this “… just go a little out of scope for AK.”

I think the reduction kernels in GPUArrays can act on lazy Broadcasted objects, so you can probably make them do this for you:

julia> begin
       n = 10
       A = randn(n, n)
       l = randn(n)
       f = ((i, j),) -> (i, j, A[i, j]^2 + (1 - l[i])*(1 + l[j]))
       op = (x, y) -> x[3] > y[3] ? x : y
       col = Iterators.product(1:n, 1:n) # simplified from product(1:k, (k+1):n), just make views of A, l as necc.
       i, j, volf = mapreduce(f, op, col; init=(0, 0, -Inf))
       end
(7, 3, 13.250911638285407)

julia> argmax(@. A^2 + (1 - l)*(1 + l'))
CartesianIndex(7, 3)

julia> Meta.@lower @. A^2 + (1 - l)*(1 + l')
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1  = +
β”‚   %2  = ^
β”‚   %3  = A
β”‚   %4  = Core.apply_type(Base.Val, 2)
β”‚   %5  = (%4)()
β”‚   %6  = Base.broadcasted(Base.literal_pow, %2, %3, %5)
β”‚   %7  = *
β”‚   %8  = Base.broadcasted(-, 1, l)
β”‚   %9  = +
β”‚   %10 = var"'"(l)
β”‚   %11 = Base.broadcasted(%9, 1, %10)
β”‚   %12 = Base.broadcasted(%7, %8, %11)
β”‚   %13 = Base.broadcasted(%1, %6, %12)
β”‚   %14 = Base.materialize(%13)
└──       return %14
))))

julia> function lazy(A, l)
       x6 = Base.broadcasted(Base.literal_pow, ^, A, Val(2))
       x8 = Base.broadcasted(-, 1, l)
       x11 = Base.broadcasted(+, 1, l')
       x12 = Base.broadcasted(*, x8, x11)
       x13 = Base.broadcasted(+, x6, x12)
       end
lazy (generic function with 1 method)

# eager

julia> argmax(Base.materialize(lazy(A, l)))
CartesianIndex(7, 3)

julia> using JLArrays

julia> argmax(Base.materialize(lazy(jl(A), jl(l))))
CartesianIndex(7, 3)

# lazy

julia> maximum(lazy(A, l))  # just iterating, I believe
13.250911638285407

julia> maximum(x for x in lazy(A, l))
13.250911638285407

julia> maximum(lazy(jl(A), jl(l)))  # using GPUArrays reduction, as iteration fails
13.250911638285407

julia> maximum(x for x in lazy(jl(A), jl(l)))
ERROR: Scalar indexing is disallowed.

# argmax

julia> argmax(lazy(A, l))
ERROR: MethodError: no method matching keys(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{…}, Nothing, typeof(+), Tuple{…}})

julia> Base.keys(bc::Base.Broadcast.Broadcasted) = CartesianIndices(axes(bc))

julia> argmax(lazy(A, l))
CartesianIndex(7, 3)

julia> argmax(lazy(jl(A), jl(l)))
ERROR: Scalar indexing is disallowed.

# better idea

julia> function lazy3(A, l)
         a1, a2 = axes(A)
         bc = lazy(A, l)
         x14 = Base.broadcasted(tuple, bc, a1, a2')
       end
lazy3 (generic function with 1 method)

julia> maximum(lazy3(A, l))
(13.250911638285407, 7, 3)

julia> maximum(lazy3(jl(A), jl(l)))
ERROR: MethodError: no method matching typemin(::Type{Tuple{Float64, Int64, Int64}})
Stacktrace:
 [1] neutral_element(::typeof(max), T::Type)
   @ GPUArrays ~/.julia/packages/GPUArrays/ouBUA/src/host/mapreduce.jl:25
 [2] _mapreduce(f::typeof(identity), op::typeof(max), As::Base.Broadcast.Broadcasted{…}; dims::Colon, init::Nothing)
   @ GPUArrays ~/.julia/packages/GPUArrays/ouBUA/src/host/mapreduce.jl:49
...

julia> Base.typemin(::Type{Tuple{T,I,J}}) where {T,I,J} = map(typemin, (T,I,J))  # piracy... could overload GPUArrays.neutral_element instead

julia> maximum(lazy3(jl(A), jl(l)))
(13.250911638285407, 7, 3)
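Putting the pieces together for the original k × (n−k) problem, here is one possible sketch (not a tested implementation): it assumes the `typemin` overload from above is in effect, that ranges and lazy views broadcast fine through the GPUArrays reduction kernels, and the helper name `argmax_block` is made up for illustration. It builds the whole tuple-carrying Broadcasted over views of A and l and reduces it in a single pass, with nothing materialized in VRAM:

using CUDA  # sketch; any GPUArrays backend (or plain Arrays) should behave the same

# pirated neutral element from the thread above, needed for GPU reductions over tuples
Base.typemin(::Type{Tuple{T,I,J}}) where {T,I,J} = map(typemin, (T,I,J))

function argmax_block(A, l, k)   # hypothetical helper, not a library function
    n = size(A, 2)
    Av = view(A, 1:k, (k+1):n)             # the A[s,t] block, taken lazily
    l1 = view(l, 1:k)
    l2 = view(l, (k+1):n)
    bc6  = Base.broadcasted(Base.literal_pow, ^, Av, Val(2))
    bc8  = Base.broadcasted(-, 1, l1)
    bc11 = Base.broadcasted(+, 1, l2')
    bc12 = Base.broadcasted(*, bc8, bc11)
    bc13 = Base.broadcasted(+, bc6, bc12)  # A[s,t]^2 + (1 - l[s])*(1 + l[t])
    tup  = Base.broadcasted(tuple, bc13, 1:k, ((k+1):n)')  # carry (value, s, t)
    volf, i, j = maximum(tup)              # one fused reduction; tuples compare lexicographically
    return i, j, volf
end

Since Julia compares tuples lexicographically, the value must come first in each tuple; ties in the value are then broken by the smaller indices, which may differ from the CPU version’s tie-breaking.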