Applying batched_mul from NNlib/NNlibCUDA to a CuArray tensor and a CuSparseMatrix falls back to the generic method and performs scalar indexing. Is there a GPU-efficient way to broadcast multiplication of a CuSparseMatrix over a CuArray tensor?
P.S. There is currently an issue where the similar call in batched_mul causes the output to be a CPU array, but even fixing that does not seem to solve the scalar indexing problem.
MWE:
using CUDA, CUDA.CUSPARSE, NNlib, NNlibCUDA
CUDA.allowscalar(false)
a = CuArray(ones(2,2,3))
b = CuSparseMatrixCSC(CuArray(ones(2,2)))
c = batched_mul(a, b)  # this does scalar indexing
If there’s a CUSPARSE routine for this, it should be possible to add dispatch to send things there from batched_mul. But what should the output be if it did work?
julia> b = CuSparseMatrixCSC(cu(ones(2,2)))
2×2 CuSparseMatrixCSC{Float32} with 4 stored entries:
sparse(Int32[1, 2, 1, 2], [1, 1, 2, 2], Float32[1.0, 1.0, 1.0, 1.0], 2, 2)
julia> similar(b, Float32, (2,2,2)) |> summary
"2×2×2 Array{Float32, 3}"
julia> CuSparse
CuSparseMatrix CuSparseMatrixCOO CuSparseMatrixCSR
CuSparseMatrixBSR CuSparseMatrixCSC CuSparseVector
batched_mul calls similar(reshape(b,2,2,1)), which currently results in an Array, but PR #1184 fixes this and makes it a CuArray, which is consistent with what happens when b is sparse rather than CuSparse. batched_mul still leads to scalar indexing in this case, though, and I’m not sure where to add a method that would prevent this. Currently it gets sent to the generic method, and I’m not sure how to prevent that from happening (or maybe that’s fine, and there is something else that needs fixing?)
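To make that concrete, the kind of method I have in mind would look roughly like the sketch below (purely illustrative, not something that exists in NNlib or NNlibCUDA, and I don’t know whether the per-slice * here would itself avoid scalar indexing):
using CUDA, CUDA.CUSPARSE, NNlib

# Hypothetical overload, only to show where a specialised method could hook in.
# It allocates a dense output and multiplies slice by slice; whether a[:, :, k] * b
# hits a CUSPARSE kernel rather than the generic fallback is exactly the open question.
function NNlib.batched_mul(a::CuArray{T,3}, b::CuSparseMatrixCSC{T}) where {T}
    c = similar(a, size(a, 1), size(b, 2), size(a, 3))
    for k in axes(a, 3)
        c[:, :, k] = a[:, :, k] * b
    end
    return c
end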
So you’d like it to return a dense 3-array?
Looking here, I don’t see any batched operations to hook this up to.
So my guess is that the most efficient thing will be to permute & reshape a to use ordinary matrix-matrix *, then permute back:
julia> using TensorCast

julia> @pretty @matmul C[i,j,k] := sum(s) A[i,s,k] * B[s,j]
begin
    # ...
    local (ax_i, ax_j, ax_k, ax_s) = (axes(A, 1), axes(B, 2), axes(A, 3), axes(A, 2))
    local magpie = transmutedims(A, (1, 3, 2))
    local crow = reshape(magpie, (star(ax_i, ax_k), ax_s))
    local caterpillar = reshape(crow * B, (ax_i, ax_k, ax_j))
    local wasp = transmutedims(caterpillar, (1, 3, 2))
    C = wasp
end
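For reference, applied to the arrays from the MWE the call is just the macro without @pretty (untested on my end, and it assumes the reshaped dense-times-sparse * dispatches to a CUSPARSE kernel rather than the generic fallback):
julia> using TensorCast, CUDA, CUDA.CUSPARSE

julia> A = CUDA.ones(2, 2, 3); B = CuSparseMatrixCSC(cu(ones(2, 2)));

julia> @matmul C[i,j,k] := sum(s) A[i,s,k] * B[s,j]   # should give a dense 2×2×3 CuArray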
Took me a while to understand what is going on here (I wasn’t familiar with TensorCast or with the animal variable names), but this makes good sense. Thanks! Do you think this will run well on a GPU, considering that axes(A,3) is very large and the whole multiplication will probably be sent as a single kernel? If I can split A[i,s,k] along the third dimension into several i×s×m arrays, is there a way to send the multiplication of B with the whole set off to the GPU at once, but with different kernels?
I think this should work on the GPU but haven’t checked. I believe that launching one big kernel is usually the best thing, but not with high confidence; you could try slicing and calling several. If I remember correctly, what batched_mul does here for dense arrays is more efficient than reshape-then-* like this, both for a ⊠ b and b ⊠ a, because it calls a different batched function which spreads the work better.
TensorCast is mostly just providing a way to get the permutation (1, 3, 2) without having to think too hard; writing something like this by hand will be equivalent. (Using axes vs. size shouldn’t matter here. And transmutedims is faster than permutedims on Arrays but should fall back to being the same for CuArrays. And @pretty is a variant of @macroexpand1.)
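Written out by hand without the chunking, the whole computation is just (again an untested sketch):
# Same computation without TensorCast: permute, do one big matmul, permute back.
function matmul_132(A::AbstractArray{T,3}, B::AbstractMatrix) where {T}
    i, s, k = size(A)
    j = size(B, 2)
    Ap = permutedims(A, (1, 3, 2))                        # i × k × s
    C2 = reshape(Ap, i * k, s) * B                        # one (i*k)×s times s×j matmul
    return permutedims(reshape(C2, i, k, j), (1, 3, 2))   # back to i × j × k
end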
Great. Thanks for the help!