To do this as one ordinary matrix multiplication, you need to permute the data, as you do. TensorOperations is often quicker at this step, thanks to its cache and a faster permutedims implementation, although today only marginally:
julia> @btime batched_mul_m($M, $x);
200.668 μs (10 allocations: 1.83 MiB)
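For context, the permute-and-reshape idea behind those timings can be sketched roughly like this (this is my own sketch of the general approach, not the actual definition of `batched_mul_m` from earlier in the thread):

```julia
# Hypothetical sketch: fold the batch dimension into the rows so that the
# whole batched product becomes one ordinary matmul, then unfold the result.
# Computes D[i,l,k] = Σⱼ A[i,j,k] * x[j,l].
function batched_mul_sketch(A::AbstractArray{T,3}, x::AbstractMatrix) where {T}
    a1, b1, bs = size(A)                 # A[i,j,k]
    b2 = size(x, 2)                      # x[j,l]
    Aperm = permutedims(A, (1, 3, 2))    # A[i,k,j] — the first permute
    C = reshape(Aperm, a1 * bs, b1) * x  # one (a1*bs)×b1 times b1×b2 matmul
    permutedims(reshape(C, a1, bs, b2), (1, 3, 2))  # un-permute back to D[i,l,k]
end
```

The two `permutedims` calls here are exactly the ones benchmarked separately below.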
julia> using TensorOperations
julia> batched_mul_tk3(A,x) = @tensor D[i,l,k] := A[i,j,k] * x[j,l];
julia> @btime batched_mul_tk3($M, $x);
191.240 μs (177 allocations: 795.13 KiB)
This turns out to be a case where avoiding BLAS entirely pays off, since you then don't have to permute at all, and permuting actually takes the majority of the time above (the first call below permutes M before the multiplication, the second un-permutes the result C):
julia> @btime permutedims($M, (1, 3, 2));
35.852 μs (2 allocations: 312.58 KiB)
julia> @btime permutedims(reshape($C, (a1, bs, b2)), (1, 3, 2));
90.409 μs (6 allocations: 781.47 KiB)
julia> using Tullio, LoopVectorization
julia> batched_mul_tk4(A,x) = @tullio D[i,l,k] := A[i,j,k] * x[j,l];
julia> batched_mul_tk4(M,x) ≈ batched_mul_tk3(M,x) ≈ batched_mul_m(M,x)
true
julia> @btime batched_mul_tk4($M, $x);
62.684 μs (51 allocations: 784.47 KiB)
Finally, I think this is also a job for batched_mul without any permutation (some PermutedDimsArrays are actually fine for it, but none are needed here). With the right branch:
julia> using NNlib # PR#191
julia> @btime batched_mul($M, reshape($x,20,50,1));
100.661 μs (25 allocations: 784.34 KiB)
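For what it's worth, the `reshape` works because (on that branch, at least) batched_mul broadcasts a batch dimension of size 1, so the single slice of `x` is re-used against every batch of `M`. A small sketch with made-up sizes:

```julia
using NNlib

M3 = randn(4, 5, 6)   # 6 batches of 4×5 matrices
x3 = randn(5, 7)      # one shared 5×7 matrix

# Giving x3 a trailing batch dimension of 1 lets it broadcast across
# all 6 batches of M3, so D[:,:,k] == M3[:,:,k] * x3 for each k.
D = batched_mul(M3, reshape(x3, 5, 7, 1))
size(D)  # should be (4, 7, 6)
```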