Using TensorOperations.jl when contracting tensors not following strict Einstein summation convention

I am trying to implement the following computation on a GPU.

@tensor S[i,j] = A[a1,j]*F[a1,i,a2]*B[a2,j]

As per the strict Einstein summation convention, this is not a valid expression: the index j appears twice on the right-hand side (and again on the left) without being summed over, so it does not work with TensorOperations.jl. Written out explicitly, the computation I want is S[i,j] = Σ_{a1,a2} A[a1,j] * F[a1,i,a2] * B[a2,j].
At the moment, I create a larger intermediate array,

@tensor temporary_location[i,j,jprime] = A[a1,j]*F[a1,i,a2]*B[a2,jprime]

and then copy the diagonal part (i.e., temporary_location[i,j,j]) to the array S. I would like to know if there is a way to avoid the unnecessary computations involved in this workaround. The ranges of the indices a1, i, and j are of the same order (up to a few hundred), while the range of the index a2 is much smaller.
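
For concreteness, the workaround looks like this (a runnable sketch; the sizes are illustrative, the real arrays are larger):

using TensorOperations

A = rand(100, 200); F = rand(100, 200, 10); B = rand(10, 200);   # (a1,j), (a1,i,a2), (a2,j)

# Build the full (i, j, jprime) intermediate, although only j == jprime is needed.
@tensor temporary_location[i, j, jprime] := A[a1, j] * F[a1, i, a2] * B[a2, jprime]

# Copy out the diagonal slice in the last two indices.
S = [temporary_location[i, j, j]
     for i in axes(temporary_location, 1), j in axes(temporary_location, 2)]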

I tried using Tullio.jl, but it seems about 20-25× slower than TensorOperations.jl.

Any help would be appreciated.

How does this compare to what you have now?

using TensorCast, TensorOperations
@cast tmp[a1, a2, j] := A[a1, j] * B[a2, j]
@tensor S[i, j] := F[a1, i, a2] * tmp[a1, a2, j]
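
(For what it's worth, the reason this sidesteps the restriction: @cast lowers to plain reshaping and broadcasting, which has no problem with the repeated index j, and the remaining contraction over a1 and a2 is strictly Einstein, so @tensor can evaluate it as an ordinary matrix multiplication.)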

For maximal performance, it is probably better to call the relevant CUBLAS functions directly for this contraction; this is not particularly difficult.

sizea1, sizei, sizea2 = size(F); sizej = size(B, 2)
temp = reshape(F, sizea1 * sizei, sizea2) * B    # contract over a2, giving an (a1*i) × j matrix
temp2 = reshape(temp, sizea1, sizei, sizej)      # reinterpret as (a1, i, j)
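
(The reshapes are free: reshape only reinterprets the contiguous (a1, i, a2) data of F as an (a1*i) × a2 matrix, so the ordinary matrix product with B contracts a2 exactly as in the original expression.)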

The final step is then a batched matrix-vector multiplication, if you read it as
S[i, j] = Σ_{a1} temp2[a1, i, j] * A[a1, j]. Here I found the function
gemv_strided_batched!, so I think something like this should work:

gemv_strided_batched!('T', 1, temp2, A, 0., S)
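
For reference, here is the intended semantics of that call written as a plain loop over the batch index (just a sketch of what the batched routine computes, not GPU code):

a1, i, j = 28, 100, 90                                # illustrative sizes
temp2 = rand(a1, i, j); A = rand(a1, j); S = zeros(i, j);
for k in axes(S, 2)                                   # one gemv per column k
    S[:, k] .= transpose(temp2[:, :, k]) * A[:, k]    # 'T' applies the k-th slice transposed
end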

There seems to be some issue with @cast.
As per your suggestion, I use

@cast tmp[a1,a2,j] := A[a1,j]*B[a2,j]

I use arrays A of shape (28,90) and B of shape (9,90), but the output has shape (180,9,180) instead of shape (28,9,90). I am unsure what’s wrong.

That should work, and it does for me; I don’t know how you got that size. Here’s working code for all three versions, although not tried on a GPU:

A = rand(28,90);
B = rand(9,90);
F = rand(28,100,9);
using TensorOperations, Tullio, TensorCast, NNlib, TensorCore

@tullio S1[i,j] := A[a1,j]*F[a1,i,a2]*B[a2,j];  # version 1: the original expression, via Tullio
size(S1)  # (100, 90)
# CPU: min 167.541 μs, mean 173.655 μs (49 allocations, 73.17 KiB)

@cast tmp[a1, a2, j] := A[a1, j] * B[a2, j];  # version 2: TensorCast + TensorOperations
size(tmp)  # (28, 9, 90)
@tensor S2[i, j] := F[a1, i, a2] * tmp[a1, a2, j];
S1 ≈ S2
# CPU: min 57.500 μs, mean 99.750 μs (16 allocations, 445.23 KiB)

temp2 = boxdot(F, B);  # version 3: TensorCore's boxdot + NNlib's batched_mul
size(temp2)  # (28, 100, 90)
temp3 = batched_mul(batched_transpose(temp2), reshape(A, size(A,1), 1, :));
S3 = reshape(temp3, size(temp3,1), :);
S1 ≈ S3
# CPU: min 301.583 μs, mean 485.143 μs (47 allocations, 2.00 MiB)
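
I have not tried it, but moving the second version to the GPU should look roughly like this, assuming a TensorOperations version whose cuTENSOR extension lets @tensor act on CuArrays:

using CUDA, cuTENSOR, TensorCast, TensorOperations

Ag, Bg, Fg = CuArray(A), CuArray(B), CuArray(F);
@cast tmpg[a1, a2, j] := Ag[a1, j] * Bg[a2, j];        # broadcasting, runs on the device
@tensor S2g[i, j] := Fg[a1, i, a2] * tmpg[a1, a2, j];  # dispatched to cuTENSOR
Array(S2g) ≈ S2                                        # check against the CPU result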

My bad. I used a different matrix than the one I intended to use.
With

using TensorCast, TensorOperations
@cast tmp[a1, a2, j] := A[a1, j] * B[a2, j]
@tensor S[i, j] := F[a1, i, a2] * tmp[a1, a2, j]

I see a 2× speedup in the case where dim(a1) = 28, dim(a2) = 9, and dim(i) = dim(j) = 90. Thanks.

When I try your suggestion, I get the following error:

ERROR: CUBLASError: an invalid value was used as an argument (code 7, CUBLAS_STATUS_INVALID_VALUE)
Stacktrace:
 [1] throw_api_error(res::CUDA.CUBLAS.cublasStatus_t)
   @ CUDA.CUBLAS ~/work/.julia/packages/CUDA/rXson/lib/cublas/libcublas.jl:11
 [2] check
   @ ~/work/.julia/packages/CUDA/rXson/lib/cublas/libcublas.jl:21 [inlined]
 [3] cublasZgemvStridedBatched
   @ ~/work/.julia/packages/CUDA/rXson/lib/utils/call.jl:26 [inlined]
 [4] gemv_strided_batched!(trans::Char, alpha::Int64, A::CuArray{…}, x::CuArray{…}, beta::Float64, y::CuArray{…})
   @ CUDA.CUBLAS ~/work/.julia/packages/CUDA/rXson/lib/cublas/wrappers.jl:420

I am not sure what’s wrong. 'T' is for transpose, right? (In case it makes any difference: my matrices are complex, although, looking at the definition of gemv_strided_batched!, it shouldn’t matter.)

Can you post fully functional minimal code (a minimal working example, or in this case, a minimal error-reproducing one)? For the contraction you want to do, it should indeed be transpose, even if the data is complex.

I am unsure what information this can convey, so do let me know if you need more.
This is the code I am trying to execute:

    display(temp2)             # shape (a1, i, j) = (3, 4, 4)
    display(exp_θ_gpu)         # shape (a1, j) = (3, 4)
    display(slater_det_gpu)    # shape (i, j) = (4, 4)
    CUDA.CUBLAS.gemv_strided_batched!('T', 1, temp2, exp_θ_gpu, 0., slater_det_gpu)

which gives me

3×4×4 CuArray{ComplexF64, 3, CUDA.Mem.DeviceBuffer}:
[:, :, 1] =
 0.0+0.0im  0.155265+0.170144im       0.24062-0.219578im     -0.155265-0.170144im
 1.0+0.0im  -0.31053+0.689596im  -9.20336e-33-4.14433e-33im   -0.31053+0.689596im
 0.0+0.0im  0.155265-0.85974im        1.21586+0.219578im     -0.155265+0.85974im

[:, :, 2] =
 0.0+0.0im  -0.0157624+0.595112im     0.841615+0.0222914im    0.0157624-0.595112im
 1.0+0.0im   0.0315249-0.160339im  2.13989e-33+4.20731e-34im  0.0315249-0.160339im
 0.0+0.0im  -0.0157624-0.434772im     0.614861-0.0222914im    0.0157624+0.434772im

[:, :, 3] =
 0.0+0.0im  -0.0675479+0.811143im     1.14713+0.0955272im    0.0675479-0.811143im
 1.0+0.0im    0.135096-0.592402im  7.9062e-33+1.80299e-33im   0.135096-0.592402im
 0.0+0.0im  -0.0675479-0.218741im    0.309346-0.0955272im    0.0675479+0.218741im

[:, :, 4] =
 0.0+0.0im   0.131172+0.479609im       0.67827-0.185505im     -0.131172-0.479609im
 1.0+0.0im  -0.262344+0.070665im  -9.43096e-34-3.50125e-33im  -0.262344+0.070665im
 0.0+0.0im   0.131172-0.550274im      0.778205+0.185505im     -0.131172+0.550274im
3×4 CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}:
 0.159602+0.987181im  0.658391+0.752676im  -0.134971+0.99085im  -0.924661+0.380791im
      1.0-0.0im            1.0-0.0im             1.0-0.0im            1.0-0.0im
 0.159602-0.987181im  0.658391-0.752676im  -0.134971-0.99085im  -0.924661-0.380791im
4×4 CuArray{ComplexF64, 2, CUDA.Mem.DeviceBuffer}:
 0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im
 0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im
 0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im
 0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im
ERROR: CUBLASError: an invalid value was used as an argument (code 7, CUBLAS_STATUS_INVALID_VALUE)
Stacktrace:
 [1] throw_api_error(res::CUDA.CUBLAS.cublasStatus_t)
   @ CUDA.CUBLAS ~/work/.julia/packages/CUDA/rXson/lib/cublas/libcublas.jl:11
 [2] check
   @ ~/work/.julia/packages/CUDA/rXson/lib/cublas/libcublas.jl:21 [inlined]
 [3] cublasZgemvStridedBatched
   @ ~/work/.julia/packages/CUDA/rXson/lib/utils/call.jl:26 [inlined]
 [4] gemv_strided_batched!(trans::Char, alpha::Int64, A::CuArray{…}, x::CuArray{…}, beta::Float64, y::CuArray{…})
   @ CUDA.CUBLAS ~/work/.julia/packages/CUDA/rXson/lib/cublas/wrappers.jl:420