Accelerate solving many matrix problems

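Funny thing here: even the fixed `CUBLAS.gels_batched!` is slower than the threaded CPU baseline. However, by stitching together code snippets from other batched solvers, I got a nice performance bump. The approach uses batched GEMM to form X'X and X'y, followed by a batched Cholesky factorization and solve. Note that the arrays above have heterogeneous heights, so I had to write a routine that pads each array with zero rows. Zero rows add nothing to X'X or X'y, so the padded problems have exactly the same solutions.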

# hyperreg(X, y)
#
# Solve a batch of independent least-squares problems X[i] * b[i] ≈ y[i] on the GPU through
# the normal equations (X[i]'X[i]) b[i] = X[i]'y[i]:
#   1. batched GEMM forms X'X and X'y,
#   2. a batched Cholesky factorization (upper triangle) factors X'X in place,
#   3. a batched triangular solve overwrites X'y with the solution.
# One method is generated for each of Float32 and Float64.
#
# Arguments
#   X : Vector of CuMatrix; every entry has the same size (m, k)
#   y : Vector of CuMatrix; every entry has the same size (m, nrhs)
# Returns
#   Vector of CuMatrix of size (k, nrhs) holding the solution for each batch entry.
# Throws
#   DimensionMismatch if length(X) != length(y) or entry sizes differ within the batch.
#   ArgumentError     if the batch is empty, X and y have different row counts, X'X of some
#                     entry is not positive definite, or cuSOLVER reports an invalid
#                     parameter.
#                     NOTE(review): cuSOLVER documents potrsBatched as supporting only
#                     nrhs == 1, so multi-column y entries are expected to be reported here.
#
# Code is borrowed liberally from CUDA.jl/lib/cusolver/wrappers.jl
#                             and CUDA.jl/lib/cublas/wrappers.jl
for (rfname, rsname, T) in
    ((:cusolverDnSpotrfBatched, :cusolverDnSpotrsBatched, :Float32),
     (:cusolverDnDpotrfBatched, :cusolverDnDpotrsBatched, :Float64))
    @eval begin
        function hyperreg(X::Vector{<:CuMatrix{$T}},
                          y::Vector{<:CuMatrix{$T}})
            cuuplo = CUDA.CUBLAS.cublasfill('U')

            if length(X) != length(y)
                throw(DimensionMismatch(string(
                    "X and y must have the same number of batch entries, got ",
                    length(X), " and ", length(y))))
            end
            if isempty(X)
                throw(ArgumentError("X and y must contain at least one batch entry"))
            end

            mx, nx = size(X[1])
            my, ny = size(y[1])
            if mx != my
                throw(ArgumentError("X Matrix must have same num of rows as y matrix"))
            end
            for (Xs, ys) in zip(X, y)
                if (size(Xs) != (mx, nx)) || (size(ys) != (my, ny))
                    throw(DimensionMismatch("Dimensions of batched array entries must be invariant"))
                end
            end

            batch = length(X)
            A = CUDA.CUBLAS.gemm_batched('T', 'N', X, X)  # X'X; overwritten by its Cholesky factor
            B = CUDA.CUBLAS.gemm_batched('T', 'N', X, y)  # X'y; overwritten by the solution

            n = size(A[1], 1)
            lda = max(1, stride(A[1], 2))
            nrhs = size(B[1], 2)
            ldb = max(1, stride(B[1], 2))

            Aptrs = CUDA.CUBLAS.unsafe_batch(A)
            Bptrs = CUDA.CUBLAS.unsafe_batch(B)
            try
                # cholesky!(X'X) for every entry; potrfBatched writes one status per entry.
                factorinfo = CUDA.zeros(Cint, batch)
                CUDA.CUSOLVER.$rfname(
                    CUDA.CUSOLVER.dense_handle(), cuuplo, n, Aptrs, lda, factorinfo, batch)

                # Solve C \ (X'y) in place in B, where C is the factor computed above.
                # potrsBatched reports a single status for the whole call.
                solveinfo = CUDA.zeros(Cint, 1)
                CUDA.CUSOLVER.$rsname(
                    CUDA.CUSOLVER.dense_handle(), cuuplo, n, nrhs, Aptrs, lda, Bptrs, ldb,
                    solveinfo, batch)

                # Copy the device statuses back to the host (this waits for both calls) and
                # report the first failure. The factorization is checked first because a
                # failed factorization makes the solve meaningless.
                fstatus = Array(factorinfo)
                bad = findfirst(!iszero, fstatus)
                if bad !== nothing
                    info = fstatus[bad]
                    if info < 0
                        throw(ArgumentError(string("Invalid value at ", -info)))
                    end
                    throw(ArgumentError(string(
                        "X'X of batch entry ", bad,
                        " is not positive definite (leading minor of order ", info, ")")))
                end
                info = Array(solveinfo)[1]
                if info != 0
                    throw(ArgumentError(string("Invalid value at ", -info)))
                end
            finally
                CUDA.CUSOLVER.unsafe_free!(Aptrs)
                CUDA.CUSOLVER.unsafe_free!(Bptrs)
            end

            return B
        end
    end
end


# pad0(M, trows[, T]) -> (sM, padded)
#
# Append zero rows to the matrix `M` so that the result has `trows` rows.
#   M     : matrix to pad (any AbstractMatrix supporting `similar` and `vcat`)
#   trows : total number of rows wanted; must be >= size(M, 1)
#   T     : type whose `zero` fills the padding; the value is converted to eltype(M).
#           Defaults to eltype(M).
# Returns `(sM, padded)`: `padded` is the new trows × size(M, 2) matrix, and `sM` is a view
# of its first size(M, 1) rows (the original data inside the padded storage).
# Throws ArgumentError if trows < size(M, 1).
function pad0(M::TM, trows::Int, ::Type{T} = eltype(TM)) where {TM<:AbstractMatrix, T}
    nrows, ncols = size(M)
    if trows < nrows
        throw(ArgumentError(string(
            "cannot pad a matrix with ", nrows, " rows to ", trows, " rows")))
    end
    # `similar` allocates storage compatible with M. Unlike calling the constructor TM(...),
    # it also works for views and wrapper types such as Adjoint.
    padding = fill!(similar(M, trows - nrows, ncols), zero(T))
    padded = vcat(M, padding)
    sM = view(padded, 1:nrows, 1:ncols)
    return (sM, padded)
end

# pad0(V, trows[, T]) -> (sV, padded)
#
# Append zeros to the vector `V` so that the result has length `trows`.
#   V     : vector to pad (any AbstractVector supporting `similar` and `vcat`)
#   trows : total length wanted; must be >= length(V)
#   T     : type whose `zero` fills the padding; the value is converted to eltype(V).
#           Defaults to eltype(V).
# Returns `(sV, padded)`: `padded` is the new length-`trows` vector, and `sV` is a view of
# its first length(V) entries (the original data inside the padded storage).
# Throws ArgumentError if trows < length(V).
function pad0(V::TV, trows::Int, ::Type{T} = eltype(TV)) where {TV<:AbstractVector, T}
    n = length(V)
    if trows < n
        throw(ArgumentError(string(
            "cannot pad a vector of length ", n, " to length ", trows)))
    end
    # `similar` allocates storage compatible with V. Unlike calling the constructor TV(...),
    # it also works for ranges, views and other wrapper types.
    padding = fill!(similar(V, trows - n), zero(T))
    padded = vcat(V, padding)
    sV = view(padded, 1:n)
    return (sV, padded)
end

# manyregmwe(T=1000, N=5000, K=10)
#
# Benchmark solving T independent least-squares problems y ≈ X*b in two ways:
#   - on the CPU: threaded Cholesky solve of the normal equations;
#   - on the GPU: the batched path in `hyperreg`.
# Then check that both give the same coefficients.
#
#   T : number of regression problems in the batch
#   N : maximum number of rows per problem. Each problem draws its row count from N÷2:N,
#       and every problem is zero-padded to N rows for the GPU.
#   K : number of regressors (columns of each X)
#
# Returns `nothing`. Logs the timings and raises an error if the two solutions differ.
function manyregmwe(T=1000, N=5000, K=10)
    # Generate problems of heterogeneous height.
    manysmallX = [rand(rand((N ÷ 2):N), K) for _ in 1:T]
    manysmally = [rand(size(x, 1)) for x in manysmallX]

    linest(X, y) = cholesky!(X' * X) \ (X' * y)
    bs = [Vector{Float64}(undef, K) for _ in 1:T]

    # CPU baseline: one problem per iteration, spread across threads.
    function cpurunreg()
        Threads.@threads for i in 1:T
            @inbounds bs[i] .= linest(manysmallX[i], manysmally[i])
        end
    end

    # Pad every problem to N rows so the batch has uniform dimensions; zero rows leave
    # X'X and X'y unchanged. NOTE: `cu` converts Float64 data to Float32, so the GPU path
    # runs in single precision while the CPU baseline runs in double precision.
    manysmallXcu = cu.((X -> pad0(X, N)[2]).(manysmallX))
    manysmallycu = cu.((y -> Matrix(reshape(pad0(y, N)[2], N, 1))).(manysmally))

    gpurunreg() = hyperreg(manysmallXcu, manysmallycu)

    # Warm-up runs so compilation is not included in the timings.
    cpurunreg()
    ys = gpurunreg()

    @info("cpu1: ")
    @btime $cpurunreg()

    # GPU work is asynchronous. Synchronize inside the timed expression so the benchmark
    # measures the kernels' execution and not only their launch.
    @info("gpu: ")
    @btime CUDA.@sync $gpurunreg()

    # Make sure the solutions are equivalent.
    cpusol = reduce(hcat, bs)
    gpusol = reduce(hcat, vec.(Matrix.(ys)))
    cpusol ≈ gpusol || error("cpu ≠ gpu")

    return nothing
end


# Run the benchmark with its default sizes:
# 1000 problems, up to 5000 rows each, 10 regressors.
manyregmwe(1000, 5000, 10)

Output:

[ Info: cpu1:
  18.275 ms (7038 allocations: 1.27 MiB)
[ Info: gpu:
  3.965 ms (16141 allocations: 490.14 KiB)