PermutedDimsArray slower than permutedims?

I’m struggling to understand why one function for multiplying a batch of matrices with a single matrix is significantly slower than another. The only difference between the functions is that PermutedDimsArray is used in place of permutedims in the slower function.

I would have thought permutedims would be slower given that it creates a copy of the array while PermutedDimsArray creates a new view. This makes me think I’ve misunderstood something fundamental with how Julia works.
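The premise here is right. `permutedims` returns a new `Array` with the data rearranged. `PermutedDimsArray` only wraps the original array and remaps indices on every access. The sketch below shows both behaviours. The names `P`, `Q` and `old` are just for illustration.

```julia
M = randn(10, 20, 200)

P = PermutedDimsArray(M, (1, 3, 2))  # lazy view: no data is copied
Q = permutedims(M, (1, 3, 2))        # eager copy: a fresh Array in the new order

# Both have the permuted size and the same elements: P[i, k, j] == M[i, j, k]
@show size(P) size(Q)                # (10, 200, 20) for both
@show P == Q                         # true
@show parent(P) === M                # true: P still refers to M's memory

# A write through the parent is visible in the view but not in the copy
old = M[1, 3, 2]
M[1, 3, 2] = old + 1
@show P[1, 2, 3] == old + 1          # true
@show Q[1, 2, 3] == old              # true
```

Saving memory is one thing. The view also leaves the data in its original order, and that order is what the replies below turn on.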

```julia
using BenchmarkTools  # for @btime

# Copy into permuted order, then do one large matrix product
function batched_mul_m(A::AbstractArray{T,3}, x::AbstractArray{T,2}) where T
    a1, a2, bs = size(A)
    b1, b2 = size(x)
    C = reshape(permutedims(A, (1, 3, 2)), (bs * a1, a2)) * x
    return permutedims(reshape(C, (a1, bs, b2)), (1, 3, 2))
end

# Same computation, but with lazy PermutedDimsArray views instead of copies
function batched_mul_tk2(A::AbstractArray{T,3}, x::AbstractArray{T,2}) where T
    a1, a2, bs = size(A)
    b1, b2 = size(x)
    C = reshape(PermutedDimsArray(A, (1, 3, 2)), (bs * a1, a2)) * x
    return PermutedDimsArray(reshape(C, (a1, bs, b2)), (1, 3, 2))
end

M = randn(10,20,200)
x = randn(20,50)

@btime batched_mul_m(M,x);
# 169.902 μs (10 allocations: 1.83 MiB)

@btime batched_mul_tk2(M,x);
# 1.408 ms (21 allocations: 782.25 KiB)
```
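Both functions should agree with the obvious alternative: one matrix product per batch slice, written into a preallocated output with `mul!` from LinearAlgebra. The sketch below checks that. `batched_mul_loop` is a hypothetical name, not something from the thread. Views of whole slices of an `Array` are strided, so each `mul!` call can use BLAS without any permutation.

```julia
using LinearAlgebra

# Hypothetical reference version: D[:, :, k] = A[:, :, k] * x for every k
function batched_mul_loop(A::AbstractArray{T,3}, x::AbstractMatrix{T}) where T
    a1, a2, bs = size(A)
    b1, b2 = size(x)
    a2 == b1 || throw(DimensionMismatch("size(A, 2) = $a2 but size(x, 1) = $b1"))
    D = similar(A, a1, b2, bs)
    for k in 1:bs
        # contiguous slice views are StridedMatrix, so mul! can call BLAS gemm
        mul!(view(D, :, :, k), view(A, :, :, k), x)
    end
    return D
end

# The permute-and-reshape version from the question, repeated so this runs alone
function batched_mul_m(A::AbstractArray{T,3}, x::AbstractArray{T,2}) where T
    a1, a2, bs = size(A)
    b1, b2 = size(x)
    C = reshape(permutedims(A, (1, 3, 2)), (bs * a1, a2)) * x
    return permutedims(reshape(C, (a1, bs, b2)), (1, 3, 2))
end

M = randn(10, 20, 200)
x = randn(20, 50)
@show batched_mul_loop(M, x) ≈ batched_mul_m(M, x)   # true
```

The per-slice version avoids both permutations. It makes 200 small BLAS calls instead of one large one, though, so which approach is faster likely depends on the sizes involved.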

With the copy of the array, the data is arranged in the right way so that Julia can call optimized BLAS routines for the matrix multiplication. With the PermutedDimsArray, reshape produces a lazy wrapper whose elements are not evenly strided in memory. Julia then probably has to fall back to a slow generic matrix-multiplication routine. (Asymptotically, the time for the copy should be negligible compared to the time for the matrix multiplication, since the former is linear and the latter is superlinear in the size of the array.)
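You can check this explanation directly. LinearAlgebra only hands a product to BLAS when the operands are strided matrices (`StridedMatrix`) of a BLAS element type such as `Float64`. The sketch below compares the two reshaped operands from the question. The variable names are illustrative.

```julia
M = randn(10, 20, 200)
a1, a2, bs = size(M)

# Eager: permutedims makes a new Array, and reshaping an Array gives a Matrix
copied = reshape(permutedims(M, (1, 3, 2)), (bs * a1, a2))
# Lazy: reshaping a PermutedDimsArray gives a Base.ReshapedArray wrapping it
lazy = reshape(PermutedDimsArray(M, (1, 3, 2)), (bs * a1, a2))

@show copied isa StridedMatrix   # true
@show lazy isa StridedMatrix     # false
@show copied == lazy             # true: same values, different memory layout

# Why no single row stride exists: stepping down one row of `lazy` is usually
# a step of 1 in M's memory, but every 10th row it jumps to the next slice.
L = LinearIndices(M)
@show L[9, 1, 1], L[10, 1, 1]    # rows 9 and 10 of column 1: (9, 10)
@show L[10, 1, 1], L[1, 1, 2]    # rows 10 and 11 of column 1: (10, 201)
```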


OK, thanks for letting me know; your explanation makes sense! I had assumed that Julia would be able to call BLAS routines on any array of floats.

To do this as one ordinary matrix multiplication, you need to permute as you do. TensorOperations is often quicker at this, both by caching and by having a faster permutedims implementation, although here the gain is only marginal:

```julia
julia> @btime batched_mul_m($M, $x);
  200.668 μs (10 allocations: 1.83 MiB)

julia> using TensorOperations
julia> batched_mul_tk3(A,x) = @tensor D[i,l,k] := A[i,j,k] * x[j,l];

julia> @btime batched_mul_tk3($M, $x);
  191.240 μs (177 allocations: 795.13 KiB)
```

This turns out to be a case where avoiding BLAS entirely pays off, because then you don't have to permute at all. The two permutations actually take the majority of the time above, roughly 126 μs of the 200 μs:

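```julia
julia> @btime permutedims($M, (1, 3, 2));
  35.852 μs (2 allocations: 312.58 KiB)

julia> a1, a2, bs = size(M); b2 = size(x, 2);

julia> C = reshape(permutedims(M, (1, 3, 2)), (bs * a1, a2)) * x;

julia> @btime permutedims(reshape($C, (a1, bs, b2)), (1, 3, 2));
  90.409 μs (6 allocations: 781.47 KiB)
```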

```julia
julia> using Tullio, LoopVectorization
julia> batched_mul_tk4(A,x) = @tullio D[i,l,k] := A[i,j,k] * x[j,l];

julia> batched_mul_tk4(M,x) ≈ batched_mul_tk3(M,x) ≈ batched_mul_m(M,x)
true

julia> @btime batched_mul_tk4($M, $x);
  62.684 μs (51 allocations: 784.47 KiB)
```
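Here is the "no permutation, no BLAS" idea written out as plain loops. This is not what `@tullio` literally generates. Tullio with LoopVectorization adds its own optimizations, such as SIMD tuning, on top. `batched_mul_loops` is a hypothetical name. The sketch only shows that the result can be accumulated directly in the output layout `D[i,l,k]`. The innermost loop runs down columns of `A` and `D` to match Julia's column-major storage.

```julia
# Hypothetical plain-loop version: D[i,l,k] = sum_j A[i,j,k] * x[j,l],
# computed in place with no permutedims and no BLAS call.
function batched_mul_loops(A::AbstractArray{T,3}, x::AbstractMatrix{T}) where T
    a1, a2, bs = size(A)
    b1, b2 = size(x)
    a2 == b1 || throw(DimensionMismatch("size(A, 2) = $a2 but size(x, 1) = $b1"))
    D = zeros(T, a1, b2, bs)
    # k is the outermost loop and i the innermost, so memory is walked in order
    @inbounds for k in 1:bs, l in 1:b2, j in 1:a2
        xjl = x[j, l]
        @simd for i in 1:a1
            D[i, l, k] += A[i, j, k] * xjl
        end
    end
    return D
end

M = randn(10, 20, 200)
x = randn(20, 50)
D = batched_mul_loops(M, x)
@show size(D)                                                  # (10, 50, 200)
@show all(D[:, :, k] ≈ M[:, :, k] * x for k in 1:size(M, 3))   # true
```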

Finally, I think this can also be done with NNlib's batched_mul, without permutation (some PermutedDimsArrays are actually fine there, but none are needed here). With the right branch:

```julia
julia> using NNlib # PR#191

julia> @btime batched_mul($M, reshape($x,20,50,1));
  100.661 μs (25 allocations: 784.34 KiB)
```