Dear Julia users,
I am trying to set up the normal equations ( https://en.wikipedia.org/wiki/Linear_least_squares ) which means constructing X'X and X'y in
(X'X+\lambda M) w = X'y
where M is SPD, and in my case, each row of X is a Kronecker product of two matrices, i.e.:
X_{i,:}=A_{i,:} \otimes B_{i,:}
This means that X has Khatri-Rao structure. In Matlab I am using propert #4 in https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product to perform the Khatri-Rao product for n rows at once. In Julia, I would like to avoid this in order to take advantage of the fact that the code is compiled. However all my attempts at generating a faster version of the code fail. In the code snippet below I tried to use naive rank-1 updates (“normal”), using a for loop with LoopVectorization, and Tullio (@mcabbott helped me greatly in a previous question on the same topic). I find that Tullio is very fast however the problem is that since it needs to be called twice (once for X'X and once for X'y), it ends up being slower that the for loop with @tturbo on my machine. I wonder if there is any way to speed the code-up further taking advantage of Julia. Here are the times on my machines:
4.333 s (226937 allocations: 6.06 MiB) Naive
1.558 s (226937 allocations: 6.06 MiB) Naive @turbo
1.531 s (226937 allocations: 6.06 MiB) Naive @tturbo
610.630 ms (51 allocations: 8.47 KiB) Loop @turbo
248.648 ms (51 allocations: 8.47 KiB) Loop @tturbo
257.008 ms (129 allocations: 6.17 KiB) Tullio
147.583 ms (43 allocations: 197.54 MiB) Matlab port
165.664 ms (35 allocations: 188.00 MiB) Wrongly improved Matlab port
For reference Matlab takes 0.0869 seconds to run this code (almost twice as fast as the Matlab port).
This is currently the only part of my code in which Matlab performs better. If I can Julia to run faster, I would abandon Matlab for good.
N = 10000;
R = 20;
M = 40;
A = rand(N,M);
B = rand(N,R);
C = rand(M*R,M*R);
x = rand(N,);
Cx = rand(M*R,);
function normal!(C,Cx,x,A,B)
temp = Vector{Float64}(undef,M*R)
@fastmath @views @inbounds for n = 1:N
kron!(temp,A[n,:],B[n,:])
C .= C.+ temp.*temp'
Cx .= Cx .+ temp.*x[n]
end
return nothing
end
function normalTurbo!(C,Cx,x,A,B)
temp = Vector{Float64}(undef,M*R)
@fastmath @views @inbounds for n = 1:N
kron!(temp,A[n,:],B[n,:])
@turbo C .= C.+ temp.*temp'
@turbo Cx .= Cx .+ temp.*x[n]
end
return nothing
end
function normalTturbo!(C,Cx,x,A,B)
temp = Vector{Float64}(undef,M*R)
@fastmath @views @inbounds for n = 1:N
kron!(temp,A[n,:],B[n,:])
@tturbo C .= C.+ temp.*temp'
@tturbo Cx .= Cx .+ temp.*x[n]
end
return nothing
end
function normalEquations!(C,Cx,x,A,B)
temp = Matrix{Float64}(undef,M,R)
C4 = reshape(C, M,R,M,R) # writing into original C
C4x = reshape(Cx, M,R)
@fastmath @inbounds @turbo for m = 1:M
for r = 1:R
for mm = 1:M
for rr = 1:R
for n = 1:N
temp = A[n,r] * B[n,m]
C4[m,r,mm,rr] = C4[m,r,mm,rr] + temp * A[n,rr] * B[n,mm]
C4x[m,r] = C4x[m,r] + temp*x[n]
end
end
end
end
end
return nothing
end
function normalEquationsP!(C,Cx,x,A,B)
temp = Matrix{Float64}(undef,M,R)
C4 = reshape(C, M,R,M,R) # writing into original C
C4x = reshape(Cx, M,R)
@fastmath @inbounds @tturbo for m = 1:M
for r = 1:R
for mm = 1:M
for rr = 1:R
for n = 1:N
temp = A[n,r] * B[n,m]
C4[m,r,mm,rr] = C4[m,r,mm,rr] + temp * A[n,rr] * B[n,mm]
C4x[m,r] = C4x[m,r] + temp*x[n]
end
end
end
end
end
return nothing
end
function normalEquationsTullio!(C,Cx,x,A,B)
C4 = reshape(C, M,R,M,R) # writing into original C
Cx4 = reshape(Cx, M,R)
@tullio C4[j,i,l,k] = A[n,i] * B[n,j] * A[n,k] * B[n,l]
@tullio Cx4[j,i] = A[n,i] * B[n,j] * x[n]
return nothing
end
function normalEquationsMatlab(CC,Cx,A,B,x)
batchSize = 10000;
for n = 1:batchSize:N
idx = min(n+batchSize-1,N);
temp = repeat(A[n:idx,:],1,R).*kron(B[n:idx,:], ones(1, M));
CC = CC+temp'*temp;
Cx = Cx+temp'*x[n:idx];
end
return nothing
end
function normalEquationsMatlabJulia(CC,Cx,A,B,x)
batchSize = 10000;
stencil = ones(1, M)
@fastmath @inbounds @views for n = 1:batchSize:N
idx = min(n+batchSize-1,N)
@turbo temp = repeat(A[n:idx,:],1,R).*kron(B[n:idx,:],stencil )
@turbo CC .= CC.+temp'*temp
@turbo Cx .= Cx.+temp'*x[n:idx]
end
return nothing
end
@btime normal!(C,Cx,x,A,B);
@btime normalTurbo!(C,Cx,x,A,B);
@btime normalTturbo!(C,Cx,x,A,B);
@btime normalEquations!(C,Cx,x,B,A);
@btime normalEquationsP!(C,Cx,x,B,A);
@btime normalEquationsTullio!(C,Cx,x,B,A);
@btime normalEquationsMatlab(C,Cx,A,B,x);
@btime normalEquationsMatlabJulia(C,Cx,A,B,x);