Speeding up matrix multiplication with a subset of a matrix's rows

Okay, I think I’ve got it. I took @under-Peter’s idea and replaced the memory copy + vectorized `dot` with a hand-written dot-product implementation:

```julia
using LinearAlgebra

# Disable BLAS multithreading for a fairer comparison
BLAS.set_num_threads(1)

"""
Compute a matrix-vector product `W[rows,:] * x`. Assume that `Wt == W'`.
"""
function partial_matvec(Wt::AbstractMatrix{T}, x, rows) where {T}
    output = zeros(T, length(rows))

    @inbounds for ii = 1:length(rows)
        row = rows[ii]

        # Compute the dot product of the current row of W with x
        result = T(0)
        @simd for col = 1:length(x)
            result += Wt[col,row] * x[col]
        end

        output[ii] = result
    end

    return output
end

# Check that partial_matvec works correctly
W = randn(256, 256);
Wt = Matrix(W');
x = randn(256);
rows = findall(rand(256) .≥ 0.5);
@assert (W*x)[rows] ≈ partial_matvec(Wt, x, rows)
```
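For context on why the transposed copy `Wt` is worth making: Julia stores matrices in column-major order, so for a fixed `row`, the loads `Wt[col,row]` walk contiguous memory, while `W[row,col]` would jump by a whole column length between loads. A minimal sketch of the two access patterns (function names are mine, not from the thread):

```julia
# Hypothetical comparison of memory-access patterns for a single row's
# dot product. Both compute the same value; only the stride differs.
function rowdot_strided(W, x, row)
    result = zero(eltype(W))
    @inbounds @simd for col in eachindex(x)
        result += W[row, col] * x[col]   # stride of size(W, 1) between loads
    end
    return result
end

function rowdot_contiguous(Wt, x, row)
    result = zero(eltype(Wt))
    @inbounds @simd for col in eachindex(x)
        result += Wt[col, row] * x[col]  # unit stride: contiguous loads
    end
    return result
end

W = randn(256, 256); Wt = Matrix(W'); x = randn(256);
@assert rowdot_strided(W, x, 7) ≈ rowdot_contiguous(Wt, x, 7)
```

The contiguous version is the one that vectorizes cleanly with `@simd`, which is presumably where most of the speedup over strided access comes from.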

Here are results for a 256×256 matrix, dropping ≈ 1/2 of the rows of W:

```julia
julia> W = randn(256, 256); Wt = Matrix(W');

julia> @benchmark W*x setup=(x=randn(256))
BenchmarkTools.Trial: 
  memory estimate:  2.13 KiB
  allocs estimate:  1
  --------------
  minimum time:     10.726 μs (0.00% GC)
  median time:      11.018 μs (0.00% GC)
  mean time:        14.682 μs (24.39% GC)
  maximum time:     35.868 ms (99.85% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark partial_matvec(Wt, x, rows) setup=(x=randn(256); rows=findall(rand(256) .≥ 0.5))
BenchmarkTools.Trial: 
  memory estimate:  896 bytes
  allocs estimate:  1
  --------------
  minimum time:     4.155 μs (0.00% GC)
  median time:      5.827 μs (0.00% GC)
  mean time:        5.858 μs (0.00% GC)
  maximum time:     11.197 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     7
```

And here’s output from a run with a 1024×1024 matrix:

```julia
julia> W = randn(1024, 1024); Wt = Matrix(W');

julia> @benchmark W*x setup=(x=randn(1024))
BenchmarkTools.Trial: 
  memory estimate:  8.13 KiB
  allocs estimate:  1
  --------------
  minimum time:     303.350 μs (0.00% GC)
  median time:      314.284 μs (0.00% GC)
  mean time:        318.422 μs (0.02% GC)
  maximum time:     894.475 μs (60.18% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark partial_matvec(Wt, x, rows) setup=(x=randn(1024); rows=findall(rand(1024) .≥ 0.5))
BenchmarkTools.Trial: 
  memory estimate:  3.69 KiB
  allocs estimate:  1
  --------------
  minimum time:     172.821 μs (0.00% GC)
  median time:      194.660 μs (0.00% GC)
  mean time:        195.771 μs (0.03% GC)
  maximum time:     974.548 μs (63.24% GC)
  --------------
  samples:          10000
  evals/sample:     1
```

I ran into all kinds of interesting problems with this one, e.g. if I set `result = 0` instead of `result = T(0)`, performance became roughly 4x to 8x worse. Ditto when I tried to replace updates to `result` in the inner loop with updates to `output[ii]`, e.g. `output[ii] += Wt[col,row] * x[col]`.
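The `result = 0` slowdown is presumably a type-stability issue: `0` is an `Int`, and the first `+=` with a `Float64` product promotes the accumulator, so the compiler has to carry a `Union{Int64, Float64}` through the hot loop instead of a fixed machine type. A sketch of the unstable variant (the function name is mine, for illustration):

```julia
# Sketch: the type-unstable variant of the inner loop. The accumulator
# starts as an Int and becomes a Float64 on the first iteration, which
# defeats clean SIMD vectorization of the loop.
function unstable_dot(Wt::AbstractMatrix{T}, x, row) where {T}
    result = 0            # Int, not T — type-unstable accumulator
    @simd for col = 1:length(x)
        result += Wt[col, row] * x[col]
    end
    return result
end
```

Running `@code_warntype unstable_dot(Wt, x, 1)` should flag the `Union` type on `result`; initializing with `T(0)` (or, equivalently, `zero(T)`) keeps the accumulator's type fixed from the start.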

Regardless, it seems like it’s working now. Thank you to everyone who contributed, I appreciate the help! :smile:
