Thank you for your comment @Ralph_Smith.
First of all, regarding MKLSparse.jl: from this comment, ANN: MKLSparse - #9 by kristoffer.carlsson, I got the impression that MKLSparse was doing type piracy, and I assumed that type piracy meant method redefinition.
Back to my code: the problem was indeed the `y::AbstractArray` declaration. As you pointed out, it made my method less specific than the method

```julia
mul!(y::AbstractVector, A::SparseMatrixCSR, v::AbstractVector, α::Number, β::Number)
```

defined in SparseMatricesCSR.jl, so the new method was never called.
If I change the declaration to the same signature,

```julia
mul!(y::AbstractVector, A::SparseMatrixCSR, x::AbstractVector, alpha::Number, beta::Number)
```

I overwrite the existing method, and the redefined method is then the one that gets called.
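The dispatch behaviour described above can be reproduced with a toy function. This is a minimal sketch using only Base Julia; `f` and its return strings are hypothetical stand-ins for `mul!` and its methods, not code from either package.

```julia
# Hypothetical toy function `f` standing in for `mul!`.
f(y::AbstractVector, x) = "original"

# A method with a broader first argument. For vectors, the AbstractVector
# method above is more specific, so this one is only reached by non-vector arrays.
f(y::AbstractArray, x) = "broader"

println(f([1.0, 2.0], 3))   # original
println(f(ones(2, 2), 3))   # broader

# Redefining the exact same signature replaces the existing method instead of
# adding a new one (argument names do not matter, only the types).
f(y::AbstractVector, x) = "redefined"
println(f([1.0, 2.0], 3))   # redefined
println(length(methods(f))) # 2
```

Because the overwritten method keeps its place in the method table, every caller that previously reached the original `AbstractVector` method now reaches the new body.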
In fact, I have now managed to implement a function that overwrites `mul!` only if the user asks for it.
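Stripped of the sparse-matrix details, the opt-in mechanism is a function that, when called, uses `@eval` to define a method with the same signature as an existing one. A minimal sketch, assuming a hypothetical module `OptInDemo` with hypothetical functions `describe` and `alternative` standing in for `mul!` and `csr_bmul!`:

```julia
# Hypothetical module illustrating an opt-in method overwrite via @eval.
module OptInDemo

# Default method, defined when the module is loaded.
describe(x::AbstractVector) = "default"

# Alternative implementation the user may opt into.
alternative(x::AbstractVector) = "alternative"

# Calling this overwrites the default method (same signature) at runtime.
# @eval evaluates in OptInDemo, so it replaces OptInDemo.describe's method.
function enable_alternative!()
    @eval describe(x::AbstractVector) = alternative(x)
    return nothing
end

end # module

println(OptInDemo.describe([1, 2]))   # default
OptInDemo.enable_alternative!()
println(OptInDemo.describe([1, 2]))   # alternative
```

The new method is visible to calls made after `enable_alternative!()` returns to top level. Code running inside the same call that performed the `@eval` would still see the old method (Julia's world-age rule) unless it goes through `Base.invokelatest`.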
With the new code:
```julia
module ThreadedSparseCSR

using SparseMatricesCSR
using Polyester
using LinearAlgebra

export csr_bmul!
#export csr_bmul

#include("matmul.jl")
#include("multithread_mul.jl")

using SparseMatricesCSR: nzrange

# y = alpha*A*x + beta*y, with the rows of A split across threads by Polyester's @batch
function csr_bmul!(y::AbstractVector, A::SparseMatrixCSR, x::AbstractVector, alpha::Number, beta::Number)
    A.n == size(x, 1) || throw(DimensionMismatch())
    A.m == size(y, 1) || throw(DimensionMismatch())
    o = getoffset(A)  # converts the stored column indices to 1-based indices
    @batch for row in 1:size(y, 1)
        @inbounds begin
            accu = zero(eltype(y))
            for nz in nzrange(A, row)
                col = A.colval[nz] + o
                accu += A.nzval[nz]*x[col]
            end
            y[row] = alpha*accu + beta*y[row]
        end
    end
    return y
end

import LinearAlgebra: mul!

# Opt-in: overwrite the mul! method from SparseMatricesCSR.jl (same signature)
# so that it forwards to the threaded kernel above.
function multithread_matmul()
    @eval function mul!(y::AbstractVector, A::SparseMatrixCSR, x::AbstractVector, alpha::Number, beta::Number)
        return csr_bmul!(y, A, x, alpha, beta)
    end
end

end # module
```
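For readers without Polyester.jl or SparseMatricesCSR.jl, the same row-parallel kernel can be sketched with Base's `Threads.@threads`. This is a sketch under stated assumptions: `MiniCSR` and `threaded_csr_mul!` are hypothetical names, the container is not the `SparseMatrixCSR` type (it stores 1-based column indices, so no offset is needed), and Polyester's `@batch` is intended to have lower scheduling overhead than `@threads`, so timings will differ.

```julia
using Base.Threads: @threads

# Hypothetical minimal CSR container with 1-based indices
# (not the SparseMatrixCSR type from SparseMatricesCSR.jl).
struct MiniCSR{T}
    m::Int               # number of rows
    n::Int               # number of columns
    rowptr::Vector{Int}  # row i occupies entries rowptr[i]:rowptr[i+1]-1
    colval::Vector{Int}  # column index of each stored entry
    nzval::Vector{T}     # value of each stored entry
end

# Build a MiniCSR from a dense matrix, scanning row by row.
function MiniCSR(A::AbstractMatrix{T}) where {T}
    m, n = size(A)
    rowptr = [1]
    colval = Int[]
    nzval = T[]
    for i in 1:m
        for j in 1:n
            if !iszero(A[i, j])
                push!(colval, j)
                push!(nzval, A[i, j])
            end
        end
        push!(rowptr, length(colval) + 1)
    end
    return MiniCSR(m, n, rowptr, colval, nzval)
end

# y = alpha*A*x + beta*y. Rows are independent, so @threads can split the row
# range across threads without two threads ever writing the same y[row].
function threaded_csr_mul!(y::AbstractVector, A::MiniCSR, x::AbstractVector, alpha::Number, beta::Number)
    A.n == length(x) || throw(DimensionMismatch())
    A.m == length(y) || throw(DimensionMismatch())
    @threads for row in 1:A.m
        accu = zero(eltype(y))
        for nz in A.rowptr[row]:(A.rowptr[row + 1] - 1)
            accu += A.nzval[nz] * x[A.colval[nz]]
        end
        y[row] = alpha * accu + beta * y[row]
    end
    return y
end

D = [1.0 0.0 2.0; 0.0 3.0 0.0; 4.0 0.0 5.0]
A = MiniCSR(D)
x = [1.0, 2.0, 3.0]
y = ones(3)
threaded_csr_mul!(y, A, x, 2.0, 1.0)
println(y)                           # [15.0, 13.0, 39.0]
println(y ≈ 2.0 .* (D * x) .+ 1.0)   # true
```

Start Julia with several threads (for example `julia -t 4`) to actually run rows in parallel; with one thread the loop simply runs serially.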
If the user does:
```julia
using LinearAlgebra, SparseArrays, SparseMatricesCSR, BenchmarkTools

m, n = 1_000, 1_000
d = 0.01
num_nzs = floor(Int, m*n*d)
rows = rand(1:m, num_nzs)
cols = rand(1:n, num_nzs)
vals = rand(num_nzs)
cscA = sparse(rows, cols, vals, m, n)
csrA = sparsecsr(rows, cols, vals, m, n)
x = rand(n)
y = similar(x)

@btime mul!($y, $csrA, $x, true, false)
#  6.621 μs (0 allocations: 0 bytes)

using ThreadedSparseCSR
@btime mul!($y, $csrA, $x, true, false)
#  6.719 μs (0 allocations: 0 bytes)   <- still no threading

ThreadedSparseCSR.multithread_matmul()
@btime mul!($y, $csrA, $x, true, false)
#  1.606 μs (1 allocation: 96 bytes)   <- mul! is now threaded!
```
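A side note on the `true, false` arguments in the benchmark: `y = similar(x)` holds uninitialized values, which may include NaN. In Julia, `Bool` `false` acts as a "strong zero" in multiplication, so the `beta*y[row]` term in `csr_bmul!` discards whatever `y` held, and the 5-argument `mul!` for SparseArrays matrices likewise does not propagate NaNs from `y` when `beta` is zero. A minimal check using only the standard library (the small matrix is an arbitrary example):

```julia
using LinearAlgebra, SparseArrays

# Bool false is a strong zero: false * NaN is 0.0, whereas 0.0 * NaN is NaN.
println(false * NaN)   # 0.0
println(0.0 * NaN)     # NaN

A = sparse([1.0 0.0; 2.0 3.0])
x = [1.0, 1.0]
y = [NaN, NaN]               # stands in for the uninitialized memory from similar(x)
mul!(y, A, x, true, false)   # y = true*A*x + false*y
println(y)                   # [1.0, 5.0]
```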
Thank you for the help!