Julia’s type system for linear algebra is quite sophisticated. I have a case where, after some reshape and transpose operations, I end up with two matrices `A` and `B` for which `A*B` falls back to the slow `generic_matmul` instead of using BLAS, and I can't tell why.
Given two arrays `A` and `B` (matrices or vectors), what determines whether the product `A*B` is computed by BLAS rather than by the `generic_matmul` fallback? Is there a way to write a predicate like `hitsblas(A, B)`? That would be very useful for debugging this situation.
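As a rough first check, a type-level version of such a predicate can be sketched with only LinearAlgebra. This is a heuristic under stated assumptions, not LinearAlgebra's actual dispatch logic: it assumes BLAS is reachable when both arguments share an element type in `LinearAlgebra.BlasFloat` (`Float32`, `Float64`, `ComplexF32`, `ComplexF64`) and each argument is a strided array or an `Adjoint`/`Transpose` of one. `hitsblas` and `isblascompatible` are hypothetical names.

```julia
using LinearAlgebra

# Hypothetical helper: true if `x` is a strided array, or a lazy
# adjoint/transpose of one, which BLAS can consume with a transpose flag.
isblascompatible(x::StridedVecOrMat) = true
isblascompatible(x::Union{Adjoint, Transpose}) = isblascompatible(parent(x))
isblascompatible(x) = false  # ReshapedArray of a Transpose, ranges, etc.

# Hypothetical heuristic: do the argument *types* make a BLAS call possible?
# Assumes BLAS needs a shared BlasFloat eltype and strided storage.
function hitsblas(A::AbstractVecOrMat, B::AbstractVecOrMat)
    T = eltype(A)
    return T === eltype(B) &&
           T <: LinearAlgebra.BlasFloat &&
           isblascompatible(A) &&
           isblascompatible(B)
end

M = rand(4, 4)
@show hitsblas(M, M')                                 # strided + adjoint wrapper
@show hitsblas(M, rand(Int, 4, 4))                    # integer eltype
@show hitsblas(reshape(transpose(M), 2, 8), rand(8, 3))  # reshaped lazy transpose
```

A `false` result points at the argument whose type breaks the strided/`BlasFloat` pattern. A `true` result only means BLAS is possible: the method that actually runs can still bypass it, so a runtime check is needed for a definitive answer.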
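One way to get such a predicate at runtime is to use Cassette.jl to intercept calls to `LinearAlgebra.BLAS.gemm!` while the product is being evaluated:

```julia
using Cassette
using LinearAlgebra

# Sentinel exception signalling that BLAS.gemm! was reached
struct CallsGemm <: Exception end

Cassette.@context CallsGemmCtx

# Inside CallsGemmCtx, any call to BLAS.gemm! throws the sentinel instead of running
function Cassette.overdub(::CallsGemmCtx, ::typeof(LinearAlgebra.BLAS.gemm!), ::Any...)
    throw(CallsGemm())
end

# Evaluate f(args...) under the context; return true if it tried to call BLAS.gemm!
function calls_gemm(f, args...)
    try
        Cassette.overdub(CallsGemmCtx(), f, args...)
    catch e
        e isa CallsGemm && return true
        rethrow()  # don't hide unrelated errors as "no gemm"
    end
    return false
end
```

This only watches `gemm!`; matrix–vector products go through `BLAS.gemv!`, which could be intercepted with an analogous `overdub` method.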
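```julia
# May print false: LinearAlgebra multiplies 2×2 and 3×3 matrices with
# unrolled kernels (matmul2x2!/matmul3x3!) instead of calling gemm!
@show calls_gemm(*, rand(3, 3), rand(3, 3))
@show calls_gemm(*, rand(10, 10), rand(10, 10))

using StaticArrays
# StaticArrays provides its own multiplication methods for SMatrix
@show calls_gemm(*, rand(SMatrix{3, 3}), rand(SMatrix{3, 3}))
```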
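The reshape-then-transpose pattern from the question is a likely culprit. A minimal sketch, assuming the lazy `transpose` is reshaped before the product (the matrix names and sizes here are made up for illustration): `reshape` of a `Transpose` yields a `Base.ReshapedArray` that is not a `StridedMatrix`, so `*` falls back to the generic method, while materializing the transpose first with `permutedims` gives a plain `Matrix`.

```julia
using LinearAlgebra

M = rand(4, 6)
B = rand(8, 5)

# Lazy: a ReshapedArray wrapping a Transpose, not strided, so no BLAS path
A_lazy = reshape(transpose(M), 3, 8)

# Materialized: permutedims (or copy(transpose(M))) returns a Matrix,
# and reshaping a Matrix gives another Matrix
A_dense = reshape(permutedims(M), 3, 8)

@show typeof(A_lazy)
@show A_lazy isa StridedMatrix   # false
@show A_dense isa StridedMatrix  # true

# Same values either way; only the multiplication method differs.
# With calls_gemm from the Cassette snippet, calls_gemm(*, A_lazy, B)
# should be false and calls_gemm(*, A_dense, B) true.
@show A_lazy * B ≈ A_dense * B
```

The fix costs one extra copy of `M`, which for large products is typically small compared with the multiplication itself.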