I am trying to develop a package that uses the multithreading provided by Polyester.jl to parallelize sparse matrix-vector multiplication for matrices in the CSR format, as defined in the package SparseMatricesCSR.jl.
Adding multithreading was trivial: I just added `@batch` to the non-threaded implementation of `mul!`. The good news is that this alone gives performance comparable to (sometimes slightly better than) the `transpose(mat)`-vector multiplication provided by MKLSparse.jl for CSC matrices.
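The row-parallel pattern described above can be sketched without any packages. This is a minimal sketch under stated assumptions: `ToyCSR`, `toy_mul!`, and `toy_tmul!` are hypothetical names (not SparseMatricesCSR.jl's API), and `Base.Threads.@threads` stands in for Polyester's `@batch`, which is applied to the loop in the same way but uses a different threading backend.

```julia
using Base.Threads: @threads

# Hypothetical minimal CSR container (NOT SparseMatricesCSR's type):
# the nonzeros of row i live at positions rowptr[i]:rowptr[i+1]-1 (1-based).
struct ToyCSR{T}
    m::Int
    n::Int
    rowptr::Vector{Int}
    colval::Vector{Int}
    nzval::Vector{T}
end

# Serial y = alpha*A*x + beta*y, one row at a time.
function toy_mul!(y, A::ToyCSR, x, alpha, beta)
    for row in 1:A.m
        accu = zero(eltype(y))
        for nz in A.rowptr[row]:(A.rowptr[row+1]-1)
            accu += A.nzval[nz] * x[A.colval[nz]]
        end
        y[row] = alpha*accu + beta*y[row]
    end
    return y
end

# Threaded version: identical body, only the outer loop is annotated.
# @threads is a stand-in for Polyester's @batch in this sketch.
function toy_tmul!(y, A::ToyCSR, x, alpha, beta)
    @threads for row in 1:A.m
        accu = zero(eltype(y))
        for nz in A.rowptr[row]:(A.rowptr[row+1]-1)
            accu += A.nzval[nz] * x[A.colval[nz]]
        end
        y[row] = alpha*accu + beta*y[row]
    end
    return y
end

# Example: A = [1 0 2; 0 3 0] stored in CSR form.
A = ToyCSR(2, 3, [1, 3, 4], [1, 3, 2], [1.0, 2.0, 3.0])
x = ones(3)
toy_tmul!(zeros(2), A, x, 1.0, 0.0)  # [3.0, 3.0]
```

Each iteration writes only its own `y[row]`, so the rows are independent and need no locking; that is why annotating the outer row loop is enough.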
Now, in the same spirit as MKLSparse.jl, I would like my package to overwrite the non-threaded `mul!` defined in SparseMatricesCSR.jl. I was hoping to do so using `@eval`. Here is what my code looks like:
```julia
module ThreadedSparseCSR

using SparseMatricesCSR
using Polyester
using LinearAlgebra

export csr_bmul!, csr_bmul

using SparseMatricesCSR: nzrange

function csr_bmul!(y::AbstractArray, A::SparseMatrixCSR, x::AbstractVector, alpha::Number, beta::Number)
    A.n == size(x, 1) || throw(DimensionMismatch())
    A.m == size(y, 1) || throw(DimensionMismatch())
    o = getoffset(A)
    @batch for row in 1:size(y, 1)
        @inbounds begin
            accu = zero(eltype(y))
            for nz in nzrange(A, row)
                col = A.colval[nz] + o
                accu += A.nzval[nz]*x[col]
            end
            y[row] = alpha*accu + beta*y[row]
        end
    end
    return y
end

import LinearAlgebra: mul!

@eval begin
    function mul!(y::AbstractArray, A::SparseMatrixCSR, x::AbstractVector, alpha::Number, beta::Number)
        return csr_bmul!(y, A, x, alpha, beta)
    end
end

end
```
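For context, a general Julia fact rather than anything specific to these packages: without `$` interpolation, wrapping a definition in `@eval` at module top level simply evaluates that definition, so it creates the same method a plain definition would and cannot by itself change which method dispatch selects. `@eval` is mainly useful for generating several methods through interpolation. A minimal sketch, where the module `EvalDemo` and the functions `plain`, `evald`, and `kind` are hypothetical:

```julia
module EvalDemo

# A plain method definition.
plain(x::Int) = "plain"

# The same kind of definition wrapped in @eval with no `$` interpolation:
# it is evaluated as-is in this module and defines one ordinary method.
@eval begin
    evald(x::Int) = "evald"
end

# The usual reason to reach for @eval: generating one method per type by
# interpolating values (`$T`, `$(string(T))`) into the quoted definition.
for T in (Float32, Float64)
    @eval kind(::$T) = $(string(T))
end

end # module

EvalDemo.evald(1)      # "evald"
EvalDemo.kind(1.0)     # "Float64"
EvalDemo.kind(1.0f0)   # "Float32"
```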
I was hoping that the code block in `@eval` would overwrite the method. However, that is not what I am seeing.
If I test the code:
```julia
using LinearAlgebra, SparseArrays, SparseMatricesCSR, BenchmarkTools

m, n = 1_000, 1_000
d = 0.01
num_nzs = floor(Int, m*n*d)
rows = rand(1:m, num_nzs)
cols = rand(1:n, num_nzs)
vals = rand(num_nzs)
cscA = sparse(rows, cols, vals, m, n)
csrA = sparsecsr(rows, cols, vals, m, n)
x = rand(n)
y = similar(x)

@btime mul!(y, csrA, x, true, false);
#   6.258 μs (0 allocations: 0 bytes)

using ThreadedSparseCSR

@btime mul!(y, csrA, x, true, false);
#   6.621 μs (0 allocations: 0 bytes)   # no speedup

@btime csr_bmul!(y, csrA, x, true, false);
#   1.655 μs (1 allocation: 96 bytes)   # speedup due to multithreading
```
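One possible explanation worth checking is dispatch specificity (an assumption: I have not verified the exact signature in SparseMatricesCSR.jl). If the package's own `mul!` types `y` as `AbstractVector`, then a new method typed `y::AbstractArray` does not replace it; it is added alongside it, and for a `Vector` `y` Julia still picks the more specific original method. The Base-only sketch below shows the effect, with hypothetical modules `Lib` and `Ext` and a hypothetical function `op!`:

```julia
# Hypothetical stand-in for the package that owns the original method.
module Lib
op!(y::AbstractVector, x::AbstractVector) = "Lib method"
end

# Hypothetical stand-in for a package that tries to "overwrite" it
# using a broader type for `y`.
module Ext
import ..Lib: op!
op!(y::AbstractArray, x::AbstractVector) = "Ext method"
end

y, x = zeros(3), ones(3)

# Both methods now exist side by side; nothing was replaced.
println(length(methods(Lib.op!)))   # 2

# For a Vector `y`, the AbstractVector method is more specific, so it wins.
println(Lib.op!(y, x))              # Lib method

# `which` reports the method dispatch picks; its `.module` is where it was defined.
m = which(Lib.op!, Tuple{typeof(y), typeof(x)})
println(m.module === Lib)           # true

# The broader method is only reached when `y` is not a vector, e.g. a matrix.
println(Lib.op!(zeros(3, 1), x))    # Ext method
```

On the real types, `@which mul!(y, csrA, x, true, false)` (from InteractiveUtils, loaded by default in the REPL) would show which method, and from which module, is actually being called.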
What am I missing regarding the use of `@eval`?
Thank you for your help!