Fast computation of row-wise Kronecker product (Khatri-Rao product)

I think you need @views instead of @simd here; otherwise the A[n,:] you are writing into is a copy, and the original A is left unchanged.
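
A minimal sketch of the difference, not the code from the question: here `fill!` is only a stand-in for whatever mutating call the loop makes, and the 2×3 matrix is made up for illustration.

```julia
A = zeros(2, 3)

# A[1, :] passed as an argument allocates a new vector,
# so the mutation never reaches A.
fill!(A[1, :], 1.0)
unchanged = all(iszero, A)

# With @views, the same slice syntax refers to A's own memory.
@views fill!(A[1, :], 1.0)
changed = A[1, :] == ones(3)
```

After running this, `unchanged` and `changed` are both `true`: only the second call modifies `A`.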

But you can also do all of this by reshaping and broadcasting, without making slices at all. Below I use my package for such things, TensorCast, although you can also just write it out by hand:

julia> M = reshape(1:6, 2,3) .+ 0.0;

julia> dotkron(M, ones(2,3))
2×9 Matrix{Float64}:
 1.0  1.0  1.0  3.0  3.0  3.0  5.0  5.0  5.0
 2.0  2.0  2.0  4.0  4.0  4.0  6.0  6.0  6.0

julia> reshape(reshape(M,2,1,:) .* ones(2,3), 2, :)
2×9 Matrix{Float64}:
 1.0  1.0  1.0  3.0  3.0  3.0  5.0  5.0  5.0
 2.0  2.0  2.0  4.0  4.0  4.0  6.0  6.0  6.0

julia> using TensorCast

julia> @cast C[i,(k,j)] := M[i,j] * ones(2,3)[i,k]
2×9 Matrix{Float64}:
 1.0  1.0  1.0  3.0  3.0  3.0  5.0  5.0  5.0
 2.0  2.0  2.0  4.0  4.0  4.0  6.0  6.0  6.0

julia> A, B, C = rand(10^5, 10), rand(10^5, 10), rand(10^5, 10); # slightly smaller!

julia> D = @btime dotkron($A, $B, $C);  # version with @views
  225.142 ms (4 allocations: 839.23 MiB)

julia> @btime @cast _[i,(l,k,j)] := $A[i,j] * $B[i,k] * $C[i,l];
  88.178 ms (11 allocations: 762.94 MiB)

julia> using Strided  # multi-threaded broadcasting package

julia> D2 = @btime @cast @strided _[i,(l,k,j)] := $A[i,j] * $B[i,k] * $C[i,l];
  53.937 ms (82 allocations: 762.95 MiB)

julia> D2 == D
true
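
For completeness, the "write it out by hand" route extends to three matrices along the same lines as the two-matrix reshape above. This is only a sketch: the name `rowkron` is hypothetical, the small sizes are made up, and I have not benchmarked it against the versions timed above.

```julia
using LinearAlgebra  # only for kron, used in the check below

# Hypothetical helper (the name rowkron is made up for this sketch).
# Computes D[i, (l,k,j)] = A[i,j] * B[i,k] * C[i,l], with l varying fastest.
function rowkron(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix)
    n = size(A, 1)
    size(B, 1) == n == size(C, 1) || throw(DimensionMismatch("row counts differ"))
    A4 = reshape(A, n, 1, 1, :)   # indices (i, 1, 1, j)
    B3 = reshape(B, n, 1, :)      # indices (i, 1, k)
    # C already has indices (i, l); broadcasting expands the singleton dimensions.
    return reshape(A4 .* B3 .* C, n, :)
end

A, B, C = rand(4, 2), rand(4, 3), rand(4, 5)
D = rowkron(A, B, C)

# Each row of D should equal the Kronecker product of the corresponding rows.
rows_match = all(D[i, :] ≈ kron(kron(A[i, :], B[i, :]), C[i, :]) for i in 1:4)
```

The index order (l, k, j) is the same as in the @cast expression above, so the columns come out in the same order.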