Batched multiplication with CuSparseMatrix isn't working

Applying batched_mul from NNlib/NNlibCUDA to a CuArray tensor and a CuSparseMatrix falls back to the generic method, which performs scalar indexing. Is there a GPU-efficient way to broadcast multiplication by a CuSparseMatrix over a CuArray tensor?

P.S. There is currently an issue where the call to similar inside batched_mul makes the output a CPU array, but even fixing that does not seem to solve the scalar-indexing problem.


using CUDA, CUDA.CUSPARSE, NNlib, NNlibCUDA

a = CuArray(ones(2,2,3))
b = CuSparseMatrix(CuArray(ones(2,2)))
c = batched_mul(a,b) # this does scalar indexing

If there’s a CUSPARSE routine for this, it should be possible to add dispatch to send things there from batched_mul. But what should the output be if it did work?

julia> b = CuSparseMatrixCSC(cu(ones(2,2)))
2×2 CuSparseMatrixCSC{Float32} with 4 stored entries:sparse(Int32[1, 2, 1, 2], [1, 1, 2, 2], Float32[1.0, 1.0, 1.0, 1.0], 2, 2)

julia> similar(b, Float32, (2,2,2)) |> summary
"2×2×2 Array{Float32, 3}"

julia> CuSparse  # press Tab to list completions
CuSparseMatrix    CuSparseMatrixCOO  CuSparseMatrixCSR
CuSparseMatrixBSR CuSparseMatrixCSC  CuSparseVector

batched_mul calls similar(reshape(b,2,2,1)), which currently returns an Array. PR #1184 fixes this so that it returns a CuArray, which is consistent with what happens when b is an ordinary (CPU) sparse matrix rather than a CuSparse one: there the result is a dense Array. batched_mul still does scalar indexing in this case, but I'm not sure where to add a method that would prevent it. Currently the call is sent to the generic method, and I'm not sure how to stop that from happening (or maybe that is fine, and something else needs fixing?)

So you’d like it to return a dense 3-array?
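Concretely, a dense 3-array result would satisfy C[:, :, k] = a[:, :, k] * b for every batch index k, which is what batched_mul returns for dense inputs. A minimal CPU reference sketch of that semantics follows; it assumes SparseArrays' `sparse` as a stand-in for CuSparseMatrixCSC, and `reference_batched_mul` is a hypothetical name, not an NNlib function.

```julia
using SparseArrays

# Hypothetical reference implementation (not part of NNlib): multiply every
# slice A[:, :, k] by the same, possibly sparse, matrix B and store the
# results densely in a 3-array.
function reference_batched_mul(A::AbstractArray{T,3}, B::AbstractMatrix) where {T}
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("size(A, 2) must equal size(B, 1)"))
    C = zeros(promote_type(T, eltype(B)), size(A, 1), size(B, 2), size(A, 3))
    for k in axes(A, 3)
        C[:, :, k] = A[:, :, k] * B   # dense * sparse; the slice of C is dense
    end
    return C
end

A = ones(2, 2, 3)
B = sparse(ones(2, 2))
C = reference_batched_mul(A, B)   # every entry is 1*1 + 1*1
```

The loop does one matrix product per slice, so it is only meant to pin down the expected output, not to be fast.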

Looking through the CUSPARSE routines, I don't see any batched operations to hook this up to.

So my guess is that the most efficient approach will be to permute and reshape a so that an ordinary matrix-matrix * can be used, then permute back:

julia> using TensorCast

julia> @pretty @matmul C[i,j,k] := sum(s) A[i,s,k] * B[s,j]
    # ...
    local (ax_i, ax_j, ax_k, ax_s) = (axes(A, 1), axes(B, 2), axes(A, 3), axes(A, 2))
    local magpie = transmutedims(A, (1, 3, 2))
    local crow = reshape(magpie, (star(ax_i, ax_k), ax_s))
    local caterpillar = reshape(crow * B, (ax_i, ax_k, ax_j))
    local wasp = transmutedims(caterpillar, (1, 3, 2))
    C = wasp
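The same steps can be written by hand with Base functions. This is a sketch under a few assumptions: it uses `permutedims` in place of TensorCast's `transmutedims` (same result), uses a CPU `SparseMatrixCSC` for `B`, and `permuted_matmul` is a hypothetical name. On the GPU one would pass a CuArray and a CuSparseMatrixCSC instead, assuming CUDA.jl supports the dense-times-sparse `*` for those types (not verified here).

```julia
using SparseArrays

# Hypothetical helper mirroring the TensorCast expansion:
# C[i, j, k] = sum over s of A[i, s, k] * B[s, j]
function permuted_matmul(A::AbstractArray{<:Any,3}, B::AbstractMatrix)
    ni, ns, nk = size(A)
    size(B, 1) == ns || throw(DimensionMismatch("size(A, 2) must equal size(B, 1)"))
    nj = size(B, 2)
    At = permutedims(A, (1, 3, 2))       # axes (i, k, s)
    flat = reshape(At, ni * nk, ns)      # rows index the pair (i, k), columns index s
    P = flat * B                         # one ordinary matrix-matrix product
    return permutedims(reshape(P, ni, nk, nj), (1, 3, 2))   # back to (i, j, k)
end

A = rand(3, 4, 5)
B = sprand(4, 2, 0.5)   # random sparse 4×2 matrix, density about 0.5
C = permuted_matmul(A, B)
```

Because all slices share the same B, stacking them as rows turns the whole batch into a single large product, which is why no batched sparse routine is needed.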

Took me a while to understand what is going on here (I wasn't familiar with TensorCast or with animal-named variables), but this makes good sense. Thanks! Do you think this will run well on a GPU, considering that axes(A,3) is very large and the whole multiplication will probably be sent as a single kernel? If I split A[i,s,k] along the third dimension into several i×s×m arrays, is there a way to send the multiplication of B with the whole set to the GPU at once, but as different kernels?

I think this should work on the GPU, but I haven't checked. I believe launching one big kernel is usually best, but not with high confidence; you could try slicing and calling several. If I remember correctly, what batched_mul does here for dense arrays is more efficient than reshape-then-* like this, both for a ⊠ b and b ⊠ a, because it calls a different batched function that spreads the work better.
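Slicing and calling several multiplications could look roughly like the sketch below: split A along its third dimension into chunks of at most `chunk` slices, do one reshaped product per chunk, and concatenate. It is demonstrated on the CPU with SparseArrays; `chunked_matmul` and the `chunk` keyword are hypothetical, and whether this beats a single large product on the GPU would need benchmarking.

```julia
using SparseArrays

# Hypothetical helper: one reshaped matrix product per chunk of at most
# `chunk` slices along dimension 3, results concatenated along dimension 3.
function chunked_matmul(A::AbstractArray{<:Any,3}, B::AbstractMatrix; chunk::Int = 2)
    ni, ns, nk = size(A)
    size(B, 1) == ns || throw(DimensionMismatch("size(A, 2) must equal size(B, 1)"))
    nj = size(B, 2)
    parts = map(Iterators.partition(1:nk, chunk)) do ks
        m = length(ks)
        flat = reshape(permutedims(A[:, :, ks], (1, 3, 2)), ni * m, ns)  # rows (i, k), columns s
        permutedims(reshape(flat * B, ni, m, nj), (1, 3, 2))            # axes (i, j, k)
    end
    return cat(parts...; dims = 3)
end

A = rand(3, 4, 5)
B = sprand(4, 2, 0.5)
C = chunked_matmul(A, B; chunk = 2)   # chunks cover slices 1:2, 3:4, 5:5
```

Written this way, the chunks still run one after another. CUDA.jl gives each Julia task its own stream, so running the chunks in separate tasks is one possible route to having several kernels in flight at once; this has not been tested here.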

TensorCast mostly just provides a way to get the permutation (1, 3, 2) without having to think too hard; writing something like this by hand is equivalent. (Using axes vs. size shouldn't matter here. transmutedims is faster than permutedims on Arrays but should fall back to the same thing for CuArrays. And @pretty is a variant of @macroexpand1.)


Great. Thanks for the help!