I’m struggling to understand why one function for multiplying a batch of matrices with a single matrix is significantly slower than another. The only difference between the functions is that PermutedDimsArray is used in place of permutedims in the slower function.
I would have thought permutedims would be slower given that it creates a copy of the array while PermutedDimsArray creates a new view. This makes me think I’ve misunderstood something fundamental with how Julia works.
function batched_mul_m(A::AbstractArray{T,3}, x::AbstractArray{T,2}) where T
a1, a2, bs = size(A)
b1, b2 = size(x)
C = reshape(permutedims(A, (1, 3, 2)), (bs * a1, a2)) * x
return permutedims(reshape(C, (a1, bs, b2)), (1, 3, 2))
end
function batched_mul_tk2(A::AbstractArray{T,3}, x::AbstractArray{T,2}) where T
a1, a2, bs = size(A)
b1, b2 = size(x)
C = reshape(PermutedDimsArray(A, (1, 3, 2)), (bs * a1, a2)) * x
return PermutedDimsArray(reshape(C, (a1, bs, b2)), (1, 3, 2))
end
M = randn(10,20,200)
x = randn(20,50)
@btime batched_mul_m(M,x);
# 169.902 μs (10 allocations: 1.83 MiB)
@btime batched_mul_tk2(M,x);
# 1.408 ms (21 allocations: 782.25 KiB)