Efficient matrix multiplication over 4-dimensional arrays

I have two 4-dimensional arrays of sizes (r1, m, n, r2) and (s1, n, p, s2), respectively. For every combination of the outer indices, the m × n slice of the first array has to be multiplied by the n × p slice of the second, and the products are collected into a 4-dimensional array of shape (r1 * s1, m, p, r2 * s2). My naive implementation does the following:
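Written out elementwise, the operation computed by the loop below is the following. The combined first and last indices follow that loop's convention (its variables i and j); a and b are index names introduced here only for illustration.

$$
\mathrm{result}\big[(i_1-1)\,s_1 + j_1,\ a,\ b,\ (i_2-1)\,s_2 + j_2\big] \;=\; \sum_{k=1}^{n} \mathrm{lhs}[i_1, a, k, i_2]\,\mathrm{rhs}[j_1, k, b, j_2]
$$

for 1 ≤ i1 ≤ r1, 1 ≤ j1 ≤ s1, 1 ≤ a ≤ m, 1 ≤ b ≤ p, 1 ≤ i2 ≤ r2 and 1 ≤ j2 ≤ s2.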

using LinearAlgebra  # provides mul!

function core_mult_col(lhs, rhs)

    r1, m, n, r2 = size(lhs)
    s1, n, p, s2 = size(rhs)

    dim_1 = r1 * s1
    dim_2 = m
    dim_3 = p
    dim_4 = r2 * s2

    # Pre-allocate memory for resulting core
    result = Array{Float64}(undef, (r1 * s1, m, p, r2 * s2))
 

    for j2 = 1:s2, i2 = 1:r2, j1 = 1:s1, i1 = 1:r1

        i = (i1-1) * s1 + j1
        j = (i2-1) * s2 + j2

        mul!(result[i, :, :, j], lhs[i1, :, :, i2] , rhs[j1, :, :, j2])

    end
end

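However, I know that this way I am making a lot of “jumps” in memory (Julia arrays are column-major, so the elements of a slice such as lhs[i1, :, :, i2] are not adjacent), and therefore I’m not being very efficient.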
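The “jumps” can be made visible with strides, which reports how many elements apart neighbouring entries are along each dimension. This small illustration assumes an array with the same shape as the lhs used in the benchmark later in the thread:

```julia
A = ones(2, 3, 4, 5)                  # same shape as lhs in the benchmark further down

strides(A)                            # (1, 2, 6, 24): the first index varies fastest

# Slice over the two middle dimensions, as taken in core_mult_col:
# neighbours are 2 elements apart along one axis and 6 apart along the other
strides(view(A, 1, :, :, 1))          # (2, 6)

# Slice over the two leading dimensions: one contiguous 2×3 block
strides(view(A, :, :, 1, 1))          # (1, 2)
```

So in Julia the dimensions that should be contiguous for each matrix product have to come first, not last.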

My other idea was to permute the dimensions of all the arrays such that the iteration occurs in contiguous blocks of memory:

function core_mult_col(lhs, rhs)

    r1, m, n, r2 = size(lhs)
    s1, n, p, s2 = size(rhs)

    dim_1 = r1 * s1
    dim_2 = m
    dim_3 = p
    dim_4 = r2 * s2

    permuted_lhs = Array{Float64}(undef, (r1, r2, m, n))
    permuted_rhs = Array{Float64}(undef, (s1, s2, n, p))

    permutedims!(permuted_lhs, lhs, (1, 4, 2, 3))
    permutedims!(permuted_rhs, rhs, (1, 4, 2, 3))

    pre_result = Array{Float64}(undef, (r1 * s1, r2 * s2, m, p))

    # here I would like to iterate over the first two dimensions of 
    # each array efficiently and multiply the slices permuted_lhs[i, j, :, :], permuted_rhs[i, j, :, :]

    result = Array{Float64}(undef, (r1 * s1, m, p, r2 * s2))

    # permutedims! takes the destination array as its first argument
    permutedims!(result, pre_result, (1, 3, 4, 2))

    return result

end

However, I’m not sure how to do this iteration efficiently. Is there a way to do this in Julia specifically?
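One point worth noting about the idea above: because Julia arrays are column-major, after permuting to (r1, r2, m, n) a slice like permuted_lhs[i, j, :, :] is still strided. A minimal sketch that instead moves the matrix dimensions to the front, so that every slice handed to mul! is a contiguous block, might look like this. The name core_mult_col_contig is hypothetical, and the sketch keeps the index convention of core_mult_col.

```julia
using LinearAlgebra

# Hypothetical variant of core_mult_col: matrix dimensions first, so every
# slice passed to mul! is contiguous in Julia's column-major layout.
function core_mult_col_contig(lhs::AbstractArray{<:Real,4}, rhs::AbstractArray{<:Real,4})
    r1, m, n, r2 = size(lhs)
    s1, n2, p, s2 = size(rhs)
    n == n2 || throw(DimensionMismatch("inner dimensions differ: $n vs $n2"))

    # (r1, m, n, r2) -> (m, n, r1, r2) and (s1, n, p, s2) -> (n, p, s1, s2)
    plhs = permutedims(lhs, (2, 3, 1, 4))
    prhs = permutedims(rhs, (2, 3, 1, 4))

    # Intermediate result, also with the matrix dimensions first: (m, p, r1*s1, r2*s2)
    T = promote_type(eltype(lhs), eltype(rhs))
    pre = Array{T}(undef, m, p, r1 * s1, r2 * s2)

    for j2 in 1:s2, i2 in 1:r2, j1 in 1:s1, i1 in 1:r1
        i = (i1 - 1) * s1 + j1          # same convention as core_mult_col
        j = (i2 - 1) * s2 + j2
        # all three views are contiguous matrices, and mul! writes into pre in place
        @views mul!(pre[:, :, i, j], plhs[:, :, i1, i2], prhs[:, :, j1, j2])
    end

    # (m, p, r1*s1, r2*s2) -> (r1*s1, m, p, r2*s2)
    return permutedims(pre, (3, 1, 2, 4))
end
```

Whether this beats the plain loop depends on the sizes: the permutations cost extra copies, but each mul! call then receives contiguous Float64 matrices that can go straight to BLAS.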

PS: I already tried the TensorOperations.jl package, but it was slower than the implementation I had in Python.

Thanks in advance.

How fast is the Tullio version of this?

using Tullio, LoopVectorization
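function f(A,B)
    @tullio C[r1, s1, m, p, r2, s2] := A[r1, m, n, r2] * B[s1, n, p, s2]
    # the index names used inside @tullio are not defined outside it, so read the sizes explicitly
    r1, m, n, r2 = size(A)
    s1, n, p, s2 = size(B)
    return reshape(C, r1*s1, m, p, r2*s2)
end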

Thanks for the reply! It also seems to be slower than the TensorOperations version, by a few hundred milliseconds.

Did you also load LoopVectorization (using LoopVectorization)? When LV is loaded, Tullio can generate faster code.

Can you check that the code above works? I get parsing errors, and then bounds errors. Also, mul!(result[i, :, :, j], ...) writes into a copy of that slice rather than into result; it needs @views in front to write into result itself.
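A minimal sketch of the first function with that fix applied. The name core_mult_col_views matches the one used in the timing further down, but this body is an assumption based on the description there (“with @views, and return result”); the index convention is left exactly as in the original.

```julia
using LinearAlgebra

function core_mult_col_views(lhs, rhs)
    r1, m, n, r2 = size(lhs)
    s1, n, p, s2 = size(rhs)

    result = Array{Float64}(undef, (r1 * s1, m, p, r2 * s2))

    for j2 = 1:s2, i2 = 1:r2, j1 = 1:s1, i1 = 1:r1
        i = (i1-1) * s1 + j1
        j = (i2-1) * s2 + j2
        # @views turns all three slices into views, so mul! writes into result
        # itself instead of into a temporary copy (and no input slices are copied)
        @views mul!(result[i, :, :, j], lhs[i1, :, :, i2], rhs[j1, :, :, j2])
    end
    return result
end
```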

For many sizes, Tullio + LoopVectorization is likely to be faster than permutedims + matmul, which is what TensorOperations rewrites this into.
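To make the permutedims + matmul strategy concrete, here is a rough sketch of it for this contraction: move the contracted index n to the end of A and to the front of B, reshape both into matrices, do a single matrix product, then permute back into the C[r1, s1, m, p, r2, s2] order used with @tensor and @tullio in this thread. It only illustrates the strategy and is not TensorOperations' actual code path; the name f_permute_gemm is hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: contract over n with one permutedims per input and one matmul.
function f_permute_gemm(A, B)
    r1, m, n, r2 = size(A)
    s1, n2, p, s2 = size(B)
    n == n2 || throw(DimensionMismatch("inner dimensions differ: $n vs $n2"))

    # (r1, m, n, r2) -> (r1, m, r2, n) -> matrix of size (r1*m*r2) × n
    Am = reshape(permutedims(A, (1, 2, 4, 3)), r1 * m * r2, n)
    # (s1, n, p, s2) -> (n, s1, p, s2) -> matrix of size n × (s1*p*s2)
    Bm = reshape(permutedims(B, (2, 1, 3, 4)), n, s1 * p * s2)

    Cm = Am * Bm                                   # one large BLAS matrix product
    C6 = reshape(Cm, r1, m, r2, s1, p, s2)

    # (r1, m, r2, s1, p, s2) -> (r1, s1, m, p, r2, s2)
    C = permutedims(C6, (1, 4, 2, 5, 3, 6))
    return reshape(C, r1 * s1, m, p, r2 * s2)
end
```

The single large product lets BLAS do most of the arithmetic, at the cost of three permutedims copies and their allocations; for arrays as small as the ones in the benchmark below, that overhead may well dominate.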

Yes, I am using it in the first cell of my Pluto notebook. I’m not sure whether it makes a difference where I put it.

You are correct, I had a typo when allocating memory for result. It should have been

result = Array{Float64}(undef, (r1 * s1, m, p, r2 * s2))

instead of

result = Array{Float64}(undef, (r1, s1, m, p, r2, s2))

I already corrected it in the original post.

I think it still needs @views. Even with them added, its result does not agree with the results from the index-notation packages (using what Oscar wrote), but I haven’t tried to track down why. Timing and comparison:

julia> using TensorOperations, Tullio, LoopVectorization

julia> function f_tensor(A,B)
           @tensor C[r1, s1, m, p, r2, s2] := A[r1, m, n, r2] * B[s1, n, p, s2]
           r1, m, n, r2 = size(A)
           s1, n, p, s2 = size(B)
           return reshape(C, r1*s1, m, p, r2*s2)
       end
f_tensor (generic function with 1 method)

julia> function f_tullio(A,B)
           @tullio C[r1, s1, m, p, r2, s2] := A[r1, m, n, r2] * B[s1, n, p, s2]
           r1, m, n, r2 = size(A)
           s1, n, p, s2 = size(B)
           return reshape(C, r1*s1, m, p, r2*s2)
       end
f_tullio (generic function with 1 method)

julia> let lhs = rand(2,3,4,5), rhs = rand(3,4,5,2)
        y1 = @btime core_mult_col_views($lhs, $rhs)  # with @views, and return result
        y2 = @btime f_tensor($lhs, $rhs)
        y3 = @btime f_tullio($lhs, $rhs)
        y1 ≈ y2, y2 ≈ y3
       end
  min 57.208 μs, mean 218.863 μs (722 allocations, 1.24 MiB)
  min 16.042 μs, mean 19.183 μs (102 allocations, 14.33 KiB)
  min 1.217 μs, mean 2.493 μs (15 allocations, 7.83 KiB)
(false, true)
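One plausible source of the false in the first comparison is the ordering of the combined indices. The loop computes i = (i1-1) * s1 + j1, so the s1 index varies fastest, whereas reshape(C, r1*s1, ...) on C[r1, s1, ...] is column-major, so the r1 index varies fastest; the same holds for r2 and s2. With r1 = 2, s1 = 3, r2 = 5 and s2 = 2 in this benchmark, the same products land in different positions. The helper names below are hypothetical, and this is a reading of the index arithmetic, not something checked in the thread:

```julia
# Sizes from the benchmark above: r1 = 2 (first dim of lhs), s1 = 3 (first dim of rhs)
R1, S1 = 2, 3

# Row index used by core_mult_col: j1 (the s1 index) varies fastest
loop_index(i1, j1) = (i1 - 1) * S1 + j1

# Row index produced by reshape(C, r1*s1, ...) for C indexed as C[r1, s1, ...]:
# column-major order, so i1 (the r1 index) varies fastest
reshape_index(i1, j1) = LinearIndices((R1, S1))[i1, j1]

(loop_index(1, 2), reshape_index(1, 2))   # (2, 3): the same product lands in different rows

# The loop's convention matches an output indexed as C[s1, r1, ...] instead
all(loop_index(i1, j1) == LinearIndices((S1, R1))[j1, i1] for i1 in 1:R1, j1 in 1:S1)   # true
```

If this is the cause, writing the output of @tensor or @tullio as C[s1, r1, m, p, s2, r2] should reproduce the layout of the loop version.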