I am trying to develop a package that uses the multithreading provided by Polyester.jl to parallelize sparse matrix-vector multiplication for matrices in the CSR format, as defined in the package SparseMatricesCSR.jl.
Adding multithreading was trivial (just adding a @batch to the non-threaded implementation of mul!). The good news is that, by just doing so, we get performance comparable to (sometimes slightly better than) the transpose(mat)-vec multiplication provided by MKLSparse.jl (for CSC matrices).
Now, in the same spirit as MKLSparse.jl, I would like my package to overwrite the non-threaded mul! defined in SparseMatricesCSR.jl. I was hoping to do so using @eval. Here is what my code looks like:
module ThreadedSparseCSR

using SparseMatricesCSR
using Polyester
using LinearAlgebra

export csr_bmul!, csr_bmul

using SparseMatricesCSR: nzrange

function csr_bmul!(y::AbstractArray, A::SparseMatrixCSR, x::AbstractVector, alpha::Number, beta::Number)
    A.n == size(x, 1) || throw(DimensionMismatch())
    A.m == size(y, 1) || throw(DimensionMismatch())
    o = getoffset(A)
    @batch for row in 1:size(y, 1)
        @inbounds begin
            accu = zero(eltype(y))
            for nz in nzrange(A, row)
                col = A.colval[nz] + o
                accu += A.nzval[nz]*x[col]
            end
            y[row] = alpha*accu + beta*y[row]
        end
    end
    return y
end

import LinearAlgebra: mul!

@eval begin
    function mul!(y::AbstractArray, A::SparseMatrixCSR, x::AbstractVector, alpha::Number, beta::Number)
        return csr_bmul!(y, A, x, alpha, beta)
    end
end

end
I was hoping that the code block in @eval would overwrite the method. However, that is not what I am seeing.
If I test the code:
using LinearAlgebra, SparseArrays, SparseMatricesCSR, BenchmarkTools
m, n = 1_000, 1_000
d = 0.01
num_nzs = floor(Int, m*n*d)
rows = rand(1:m, num_nzs)
cols = rand(1:n, num_nzs)
vals = rand(num_nzs)
cscA = sparse(rows, cols, vals, m, n)
csrA = sparsecsr(rows, cols, vals, m, n)
x = rand(n)
y = similar(x)
@btime mul!(y, csrA, x, true, false);
6.258 μs (0 allocations: 0 bytes)
using ThreadedSparseCSR
@btime mul!(y, csrA, x, true, false);
6.621 μs (0 allocations: 0 bytes) # no speed up
@btime csr_bmul!(y, csrA, x, true, false);
1.655 μs (1 allocation: 96 bytes) # speed up due to multithreading
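For reference, this is the kind of check I would expect to pass if the overwrite had taken effect. It is only a minimal sketch with made-up modules: Upstream, Override, and apply! are hypothetical stand-ins, not the real packages. My understanding is that when a method is redefined with an identical signature, @which should point at the overriding module and the number of methods should not grow.

```julia
using InteractiveUtils  # provides @which outside the REPL

# Hypothetical stand-in for the package that owns the original method.
module Upstream
apply!(y::AbstractVector, x::AbstractVector) = (y .= x)
end

# Hypothetical stand-in for a package that redefines that method
# with an identical signature (Julia may print an overwrite warning).
module Override
import ..Upstream: apply!
apply!(y::AbstractVector, x::AbstractVector) = (y .= 2 .* x)
end

using .Upstream: apply!

y = zeros(3)
x = ones(3)

selected = @which apply!(y, x)       # the Method that dispatch picks for these argument types
owner = selected.module              # module in which that method was defined
n_methods = length(methods(apply!))  # stays at 1 when a method is replaced rather than added

println("selected method lives in: ", owner)
println("number of methods: ", n_methods)
```

Running the same kind of check on the real call, for example `@which mul!(y, csrA, x, true, false)`, should presumably show whether my method replaced the original or was only added alongside it.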
What am I missing regarding the use of @eval?
Thank you for your help!
EDIT: typo