Hi all!
I am currently working on an assembly problem. Given a dense matrix A and a sparse matrix B, the goal is to efficiently compute the matrix H = B^\top (A \otimes A) B. If one writes B as B = (\mathrm{vec}(B_1), \dots, \mathrm{vec}(B_n)) with square matrices B_j, where \mathrm{vec}(\cdot) denotes the vectorization of a matrix, then the (i,j)-th entry of H is \mathrm{trace}(A B_j A B_i) for symmetric A and B_i (in general it is \mathrm{trace}(B_i^\top A B_j A^\top)). Matrices of this (and similar) structure appear in interior point methods for solving linear semidefinite programs.
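For completeness, here is how I get from the Kronecker form to the trace form. It only uses the standard identity (A \otimes A)\,\mathrm{vec}(X) = \mathrm{vec}(A X A^\top); the last step assumes that A and B_i are symmetric.

```latex
\begin{aligned}
H_{ij} &= \mathrm{vec}(B_i)^\top (A \otimes A)\, \mathrm{vec}(B_j) \\
       &= \mathrm{vec}(B_i)^\top \mathrm{vec}(A B_j A^\top) \\
       &= \mathrm{trace}(B_i^\top A B_j A^\top) \\
       &= \mathrm{trace}(A B_j A B_i)
          \quad \text{if } A = A^\top \text{ and } B_i = B_i^\top .
\end{aligned}
```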
I have the following Matlab code with which I am quite happy as it is reasonably fast and clean:
n = 30;
A = ones(n, n);
B = sprand(n^2, n^2/2, 0.1);
AssembleHessian(A, B);
function C = AssembleHessian(A, B)
% Number of rows of the matrix A
nA = size(A, 1);
% Build the matrix [A*B1, ..., A*Bn] which is dense in general.
C = reshape(B, nA, []);
C = A * C;
% Generate a 3D array in which C(:, :, j) = A*Bj.
C = reshape(C, nA, nA, []);
% Multiply each of these slices by A from the right.
C = pagemtimes(C, A);
% This reshape results in the matrix [vec(A*B1*A), ..., vec(A*Bn*A)]
C = reshape(C, nA^2, []);
% Last multiplication results in desired matrix and is not too
% expensive as B is sparse.
C = B' * C;
end
I tried to achieve the same or better in Julia, but so far this didn’t work out as I am not that familiar with it. Here is what I currently have:
using LinearAlgebra
using SparseArrays
using BenchmarkTools
n = 30;
A = ones(n, n);
B = sprand(n^2, Int(n^2/2), 0.1);
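function fun(A::Matrix{Float64}, B::SparseMatrixCSC{Float64, Int64})
    nA = size(A, 1)
    # Build [B1, ..., Bn] as one wide sparse matrix and multiply by A from the left.
    C = SparseMatrixCSC(reshape(B, nA, :))
    C = A * C
    # Multiply each nA-by-nA block A*Bi by A from the right.
    for i in axes(B, 2)
        C[:, 1+(i-1)*nA:i*nA] = view(C, :, 1+(i-1)*nA:i*nA) * A
    end
    # Columns are now vec(A*Bi*A); the last product with B' gives the result.
    C = reshape(C, nA^2, :)
    C = B' * C
    return C
end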
@btime fun($A, $B);
For me this results in
29.840 ms (478 allocations: 9.48 MiB)
which is, unfortunately, about 3 times slower than the Matlab function.
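To check that a faster version still computes the right thing, I compare against the direct Kronecker formula on a small instance. This is only a minimal sketch: the names `n_small`, `A_small`, `B_small`, `H_ref` and `H_slices` are made up for this check, and the sizes are chosen so that the dense Kronecker product stays tiny. Note that (A \otimes A) \mathrm{vec}(X) = \mathrm{vec}(A X A^\top), so the sketch uses `A_small'` on the right; for symmetric A this coincides with the A*Bj*A slices.

```julia
using LinearAlgebra
using SparseArrays

# Small hypothetical instance: the Kronecker product is only 16-by-16 here.
n_small = 4
A_small = rand(n_small, n_small)
B_small = sprand(n_small^2, 8, 0.3)

# Direct reference: H = B' * (A ⊗ A) * B.
H_ref = B_small' * kron(A_small, A_small) * B_small

# Column-by-column: H[:, j] = B' * vec(A * B_j * A').
H_slices = zeros(size(B_small, 2), size(B_small, 2))
for j in axes(B_small, 2)
    Bj = reshape(Vector(B_small[:, j]), n_small, n_small)
    H_slices[:, j] = B_small' * vec(A_small * Bj * A_small')
end

# Both constructions should agree up to floating-point rounding.
@assert H_ref ≈ H_slices
```

The same comparison could be run against `fun` with a symmetric test matrix, since `fun` multiplies by A rather than by its transpose on the right.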
My output of versioninfo()
Julia Version 1.9.1
Commit 147bdf428c (2023-06-07 08:27 UTC)
Platform Info:
OS: Windows (x86_64-w64-mingw32)
CPU: 8 × Intel(R) Core(TM) i5-8265U CPU @ 1.60GHz
LIBM: libopenlibm
LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
Threads: 8 on 8 virtual cores
Any suggestions are welcome. Thanks!