Batch matrix/vector operations with CUDA.jl

I’m trying to perform various operations, such as multiplication, inversion, linear solves and Cholesky decomposition, in batches with CUDA.jl.

As far as I’m aware, there is no high-level API for this, unlike TensorFlow/PyTorch, where e.g. tf.linalg.inv natively supports batched inputs.

Instead, I have been trying to use the low-level CUBLAS wrappers. I would have expected these to accept multi-dimensional arrays of shape […, N, N], but most of them instead seem to accept vectors of CuArrays.
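For reference, the two batch layouts are easy to convert between. Julia arrays are column-major, so the batch index naturally goes last, (N, N, batch), rather than first as in the […, N, N] convention of TensorFlow/PyTorch. This is a minimal sketch using plain CPU Arrays; the same cat call should carry over to CuArrays, but that is an assumption not exercised here.

```julia
# Convert between a "vector of matrices" batch and a 3-D (strided) batch.
# Julia is column-major, so the batch index goes in the LAST dimension.

N, batch = 3, 4
As = [rand(Float32, N, N) for _ in 1:batch]   # vector-of-matrices layout

# Stack into an N×N×batch array (stack(As) does the same on Julia ≥ 1.9).
A3 = cat(As...; dims=3)

# Go back: copy each slice A3[:, :, k] into its own matrix.
As_back = [A3[:, :, k] for k in axes(A3, 3)]

size(A3)        # (3, 3, 4)
As_back == As   # true
```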

Here is an example using gemv_batched!.
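As a hedged reference for the semantics (not the original example): with trans = 'N', a batched matrix-vector product computes y[k] = alpha*A[k]*x[k] + beta*y[k] independently for every k. The CPU loop below spells that out with LinearAlgebra's five-argument mul!. The name gemv_batched_ref! is hypothetical; the GPU call in the comment, CUBLAS.gemv_batched!('N', alpha, dA, dx, beta, dy) on vectors of CuMatrix/CuVector, is an assumed signature and is not run here.

```julia
using LinearAlgebra

# Hypothetical CPU reference (not a CUDA.jl API) for a batched gemv, trans = 'N':
# for every k, y[k] = alpha * A[k] * x[k] + beta * y[k], updated in place.
function gemv_batched_ref!(alpha, A::Vector{<:AbstractMatrix},
                           x::Vector{<:AbstractVector}, beta,
                           y::Vector{<:AbstractVector})
    length(A) == length(x) == length(y) ||
        throw(DimensionMismatch("batch sizes differ"))
    for k in eachindex(A)
        # 5-arg mul!: y[k] = alpha*A[k]*x[k] + beta*y[k]
        mul!(y[k], A[k], x[k], alpha, beta)
    end
    return y
end

# Assumed GPU counterpart (not run here):
#   CUBLAS.gemv_batched!('N', alpha, dA, dx, beta, dy)
# with dA::Vector{<:CuMatrix} and dx, dy::Vector{<:CuVector}.

A = [Float64[1 2; 3 4], Float64[0 1; 1 0]]
x = [Float64[1, 1], Float64[2, 3]]
y = [zeros(2), ones(2)]
gemv_batched_ref!(2.0, A, x, 0.5, y)   # y == [[6.0, 14.0], [6.5, 4.5]]
```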

I wanted to confirm two things:

  1. Is this the correct way to be performing batched operations with CUDA.jl?
  2. Are these operations actually parallelised over the batch or just within each single operation (it’s the vector of CuArrays that is making me doubt this)?

As a follow-up question, are there any plans to implement a TensorFlow/PyTorch-style high-level interface for batch operations? If so, I would be happy to help where I can, as my research would benefit greatly from such implementations.

The strided_batched methods are the ones accepting multi-dimensional arrays (as per NVIDIA naming), and are supposed to be faster than the ones using vectors of GPU arrays: https://developer.nvidia.com/blog/cublas-strided-batched-matrix-multiply/
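To make the naming concrete: in the strided-batched layout all matrices live in one 3-D array, so consecutive matrices sit at a constant offset (the stride) in memory, and one base pointer plus that stride describes the whole batch instead of an array of pointers. The CPU sketch below, with the hypothetical name gemm_strided_batched_ref!, computes C[:, :, k] = A[:, :, k] * B[:, :, k] slice by slice using views. The assumed GPU counterpart would be CUBLAS.gemm_strided_batched!('N', 'N', 1, dA, dB, 0, dC) on 3-D CuArrays; it is not run here.

```julia
using LinearAlgebra

# Hypothetical CPU reference for a strided-batched GEMM:
# C[:, :, k] = A[:, :, k] * B[:, :, k] for every k.
function gemm_strided_batched_ref!(C::AbstractArray{T,3}, A::AbstractArray{T,3},
                                   B::AbstractArray{T,3}) where {T}
    size(A, 3) == size(B, 3) == size(C, 3) ||
        throw(DimensionMismatch("batch sizes differ"))
    for k in axes(A, 3)
        # Views avoid copying each slice; mul! writes into C in place.
        mul!(view(C, :, :, k), view(A, :, :, k), view(B, :, :, k))
    end
    return C
end

m, n, p, batch = 2, 3, 4, 5
A = rand(m, n, batch)
B = rand(n, p, batch)
C = zeros(m, p, batch)
gemm_strided_batched_ref!(C, A, B)

# In a contiguous Array, matrix k+1 starts m*n elements after matrix k:
strides(A)   # (1, 2, 6): the batch stride is m*n = 6
```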

Other than that, I’m not terribly familiar with the use or design of batched APIs, so help is always appreciated. There is some existing work though, like batched_mul! in NNlib.jl, Batched.jl, BatchedBLAS.jl, etc.

We have a few low-level wrappers for batch operations in CUBLAS and CUSOLVER:

I confirm that these operations are parallelized over the batch.
I am reluctant to add a high-level dispatch (mul!, cholesky, etc.) because we don’t have a “batch” version for the CPU.

Thank you both for clarifying!

This is a very fair point and I can see how this could lead to confusion. How would you feel about adding a high-level interface specifically for the non-strided batch case?

I believe all of these are already supported in Julia through dot broadcasting over vectors of matrices, e.g.:

```julia
As = [rand(2, 2) for i in 1:10]
Bs = [rand(2, 3) for i in 1:10]
As .* Bs   # matrix * applied pairwise: result[i] == As[i] * Bs[i], each 2×3
```

We could just intercept the broadcasting and replace it with a call to gemm_batched.
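A rough illustration of the interception idea: X .* Y lowers to Base.materialize(Base.broadcasted(*, X, Y)), so a Base.Broadcast.broadcasted method specialised on a batch type can return the result of one batched call instead of a lazy Broadcasted object. In this CPU sketch the wrapper type MatBatch and the function batched_gemm are hypothetical stand-ins; in CUDA.jl such a method would presumably dispatch on vectors of CuMatrix and call gemm_batched, which is not shown or tested here.

```julia
# Hypothetical wrapper marking a vector of matrices as a "batch".
struct MatBatch{M<:AbstractMatrix}
    mats::Vector{M}
end

# Hypothetical stand-in for a single batched GEMM call (gemm_batched on a GPU).
# Here it is just a loop over the pairs.
function batched_gemm(As::Vector{<:AbstractMatrix}, Bs::Vector{<:AbstractMatrix})
    length(As) == length(Bs) || throw(DimensionMismatch("batch sizes differ"))
    return [A * B for (A, B) in zip(As, Bs)]
end

# `X .* Y` lowers to Base.materialize(Base.broadcasted(*, X, Y)).
# This method is more specific than Base's generic broadcasted, so it
# intercepts the call; materialize(x) returns x unchanged for non-Broadcasted x.
Base.Broadcast.broadcasted(::typeof(*), X::MatBatch, Y::MatBatch) =
    MatBatch(batched_gemm(X.mats, Y.mats))

As = [rand(2, 2) for _ in 1:10]
Bs = [rand(2, 3) for _ in 1:10]
Cs = MatBatch(As) .* MatBatch(Bs)   # routed through batched_gemm
```

Using a wrapper here (or, inside CUDA.jl, dispatching on element types it owns such as CuMatrix) avoids overloading broadcasting for plain Vector{Matrix} from Base, which would be type piracy.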

It’s an excellent idea, @TimHargreaves. If we can intercept the .*, it makes sense to dispatch to gemm_batched.

We just need to check that a recent CUDA release (12.4) is installed, because before that every matrix in As had to have the same dimensions (and likewise for Bs).
It also explains why we never thought of this before: the CUDA routine was too restrictive about the shapes of As and Bs.
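When only the uniform-shape routine is available, one possible workaround, offered here as an assumption rather than as existing CUDA.jl behaviour, is to group the pairs by shape and issue one uniform-shape batched call per group. The sketch below does the grouping on the CPU; group_by_shape is a hypothetical helper, and a plain loop stands in for the per-group gemm_batched call.

```julia
# Hypothetical helper: group indices i of the pairs (As[i], Bs[i]) by
# (size(As[i]), size(Bs[i])), so each group has uniform dimensions and could
# be sent to a single fixed-shape batched GEMM.
function group_by_shape(As::Vector{<:AbstractMatrix}, Bs::Vector{<:AbstractMatrix})
    length(As) == length(Bs) || throw(DimensionMismatch("batch sizes differ"))
    groups = Dict{Tuple{Tuple{Int,Int},Tuple{Int,Int}}, Vector{Int}}()
    for i in eachindex(As)
        key = (size(As[i]), size(Bs[i]))
        push!(get!(groups, key, Int[]), i)
    end
    return groups
end

# Mixed shapes, interleaved: 2×2 * 2×3 at odd i, 3×3 * 3×1 at even i.
As = [isodd(i) ? rand(2, 2) : rand(3, 3) for i in 1:6]
Bs = [isodd(i) ? rand(2, 3) : rand(3, 1) for i in 1:6]

groups = group_by_shape(As, Bs)
Cs = Vector{Matrix{Float64}}(undef, length(As))
for idx in values(groups)
    # One uniform-shape batch per group; this loop stands in for gemm_batched.
    for i in idx
        Cs[i] = As[i] * Bs[i]
    end
end
```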