Fast computation of row-wise Kronecker product (Khatri-Rao product)

Hello! I have recently started to migrate to Julia, and decided to port some of my Matlab code. I am struggling to make a performant row-wise Kronecker product (Khatri-Rao product) implementation for two or three matrices. At the moment I am doing the following:

using LinearAlgebra  # for kron!

function dotkron!(A::Matrix{Float64},B::Matrix{Float64},C::Matrix{Float64})
    N = size(A,1)
    @inbounds @simd for n = 1:N
        kron!(A[n,:],B[n,:],C[n,:])
    end
end

function dotkron(A::Matrix{Float64},B::Matrix{Float64})
    (N,DA) = size(A)
    (N,DB) = size(B)
    C = Matrix{Float64}(undef,N,DA*DB)
    dotkron!(C,A,B)
    return C
end

function dotkron!(A::Matrix{Float64},B::Matrix{Float64},C::Matrix{Float64},D::Matrix{Float64})
    dotkron!(A,dotkron(B,C),D)
end

function dotkron(A::Matrix{Float64},B::Matrix{Float64},C::Matrix{Float64})
    (N,DA) = size(A)
    (N,DB) = size(B)
    (N,DC) = size(C)
    D = Matrix{Float64}(undef,N,DA*DB*DC)
    dotkron!(D,A,B,C)
    return D
end

Can it be done more neatly and faster? For reference, on my system:

using BenchmarkTools
A=rand(1000000,10);
B=rand(1000000,10);
@btime dotkron(A,B);
  1.229 s (9000002 allocations: 2.15 GiB)

Matlab is approx. 3-4 times faster.
(Apologies for posting here, it is my first post, but this seems the more fitting category)

The big improvement would be

    @inbounds @simd for n = 1:N
        @views kron!(A[n,:],B[n,:],C[n,:])
    end

I time that at 2x faster, with much less memory allocated.
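For completeness, the whole corrected method would be something like this (a sketch with the same signature as in the question; I've also added a `return A` so the result can be used directly):

```julia
using LinearAlgebra  # kron! lives here

function dotkron!(A::Matrix{Float64}, B::Matrix{Float64}, C::Matrix{Float64})
    N = size(A, 1)
    @inbounds @simd for n = 1:N
        # @views turns the slices into aliases of the parent arrays,
        # so kron! writes into row n of A instead of into a throwaway copy
        @views kron!(A[n, :], B[n, :], C[n, :])
    end
    return A
end
```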


I think you need @views instead of @simd here, otherwise the A[n,:] into which you are writing is a copy, and the original A is unchanged.
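The difference is easy to see on a small example (generic Julia, nothing specific to kron!):

```julia
A = zeros(2, 3)

row = A[2, :]         # plain slicing returns a copy
row .= 1.0            # mutates only that copy; A is still all zeros

rowv = @view A[2, :]  # a view aliases A's memory
rowv .= 1.0           # now row 2 of A really changes
```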

But you can also do this all by reshaping and broadcasting, instead of making slices at all. Using my package for such things (although you can also just write it out by hand):

julia> M = reshape(1:6, 2,3) .+ 0.0;

julia> dotkron(M, ones(2,3))
2×9 Matrix{Float64}:
 1.0  1.0  1.0  3.0  3.0  3.0  5.0  5.0  5.0
 2.0  2.0  2.0  4.0  4.0  4.0  6.0  6.0  6.0

julia> reshape(reshape(M,2,1,:) .* ones(2,3), 2, :)
2×9 Matrix{Float64}:
 1.0  1.0  1.0  3.0  3.0  3.0  5.0  5.0  5.0
 2.0  2.0  2.0  4.0  4.0  4.0  6.0  6.0  6.0
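Wrapped up as functions, the hand-written broadcast version could look like this (`rowkron` is just a name I'm making up here):

```julia
# row-wise Kronecker by broadcasting: out[i, k + (j-1)*size(B,2)] = A[i,j] * B[i,k]
rowkron(A, B) = reshape(reshape(A, size(A, 1), 1, :) .* B, size(A, 1), :)

# three-matrix version: same trick with one more singleton dimension,
# giving out[i, (l,k,j)] = A[i,j] * B[i,k] * C[i,l] with l fastest
rowkron(A, B, C) =
    reshape(reshape(A, size(A, 1), 1, 1, :) .* reshape(B, size(B, 1), 1, :) .* C,
            size(A, 1), :)
```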

julia> using TensorCast

julia> @cast C[i,(k,j)] := M[i,j] * ones(2,3)[i,k]
2×9 Matrix{Float64}:
 1.0  1.0  1.0  3.0  3.0  3.0  5.0  5.0  5.0
 2.0  2.0  2.0  4.0  4.0  4.0  6.0  6.0  6.0

julia> A, B, C = rand(10^5, 10), rand(10^5, 10), rand(10^5, 10); # slightly smaller!

julia> D = @btime dotkron($A, $B, $C);  # version with @views
  225.142 ms (4 allocations: 839.23 MiB)

julia> @btime @cast _[i,(l,k,j)] := $A[i,j] * $B[i,k] * $C[i,l];
  88.178 ms (11 allocations: 762.94 MiB)

julia> using Strided  # multi-threaded broadcasting package

julia> D2 = @btime @cast @strided _[i,(l,k,j)] := $A[i,j] * $B[i,k] * $C[i,l];
  53.937 ms (82 allocations: 762.95 MiB)

julia> D2 == D
true

kron!.(eachrow(A),eachrow(B),eachrow(C)) is also pretty good (especially for how simple it is).

Edit: you get another 15% from transposing the problem and using kron!.(eachcol(A),eachcol(B),eachcol(C))
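Spelled out with the destination allocated explicitly (my own variable names), that one-liner is used like so — note the first argument of kron! is the output, so here A receives the result:

```julia
using LinearAlgebra

N, DB, DC = 1000, 10, 10
B = rand(N, DB)
C = rand(N, DC)
A = Matrix{Float64}(undef, N, DB * DC)       # destination, one row per sample
kron!.(eachrow(A), eachrow(B), eachrow(C));  # eachrow gives views, so kron! writes through into A
```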


Have you looked at TensorDecompositions.jl? It’s old but you might be able to use their code for this.

There is also this question and answer on Kronecker.jl that might help. They would likely add your function if it is not already covered by the package, and also help write a performant implementation.


Thank you all for these really good suggestions. I guess @views was the obvious one which I should not have missed, great improvement indeed! @mcabbott I will try your package, seems extremely promising!