How to tell if A*B hits BLAS?

Julia’s type system for linear algebra is quite sophisticated. I have a case where, after some reshape and transpose operations, I end up with two matrices A and B for which A*B falls back to the slow generic fallback (generic_matmatmul) instead of using BLAS, and I can't tell why.

Given two arrays A and B (matrices or vectors), what determines whether the product A*B is carried out by BLAS rather than by the generic fallback (generic_matmatmul)? Is there a way to get something like a predicate hitsblas(A, B)? That would be very useful for debugging this situation.


How about something like this, which uses Cassette to detect whether a call ever reaches BLAS.gemm!:

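using Cassette
using LinearAlgebra

# Sentinel thrown as soon as execution reaches gemm!
struct CallsGemm <: Exception end

Cassette.@context CallsGemmCtx

# Intercept any call to BLAS.gemm! made while running under CallsGemmCtx.
# Only gemm! is intercepted; other BLAS routines (e.g. gemv!) would need their own method.
function Cassette.overdub(::CallsGemmCtx, ::typeof(LinearAlgebra.BLAS.gemm!), ::Any...)
    throw(CallsGemm())
end

# Returns true if f(args...) reaches BLAS.gemm!, false otherwise
function calls_gemm(f, args...)
    try
        Cassette.overdub(CallsGemmCtx(), f, args...)
    catch e
        e isa CallsGemm && return true
        rethrow()  # do not hide unrelated errors
    end
    return false
end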

@show calls_gemm(*, rand(3, 3), rand(3, 3))
@show calls_gemm(*, rand(10, 10), rand(10, 10))

using StaticArrays
@show calls_gemm(*, rand(SMatrix{3, 3}), rand(SMatrix{3, 3}))

Output:

calls_gemm(*, rand(3, 3), rand(3, 3)) = false
calls_gemm(*, rand(10, 10), rand(10, 10)) = true
calls_gemm(*, rand(SMatrix{3, 3}), rand(SMatrix{3, 3})) = false
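
As a complement to the runtime check, here is a rough type-based sketch of the hitsblas(A, B) predicate asked for above. It is only a heuristic under stated assumptions: the names maybe_hits_blas and unwrap are made up for this example, and the rule it encodes (same BLAS-compatible element type, strided storage, possibly behind a lazy transpose or adjoint) is an approximation of the dispatch in LinearAlgebra, not the exact rule, which also differs between Julia versions. It also ignores matrix size, so it will not reproduce the 3×3 result shown above.

```julia
using LinearAlgebra

# Hypothetical helper: look through lazy Transpose/Adjoint wrappers
unwrap(A::Union{Transpose, Adjoint}) = parent(A)
unwrap(A) = A

# Heuristic only (an assumption, not the exact dispatch rule):
# both arguments share one BLAS element type and are strided in memory.
function maybe_hits_blas(A::AbstractVecOrMat, B::AbstractVecOrMat)
    T = eltype(A)
    T === eltype(B) || return false                 # mixed element types
    T <: LinearAlgebra.BlasFloat || return false    # Float32/64, ComplexF32/64
    return unwrap(A) isa StridedVecOrMat && unwrap(B) isa StridedVecOrMat
end

# What the helper itself returns (not a guarantee about the actual call path):
maybe_hits_blas(rand(10, 10), rand(10, 10))                        # true
maybe_hits_blas(rand(10, 10)', rand(10, 10))                       # true
maybe_hits_blas(rand(Int, 4, 4), rand(Int, 4, 4))                  # false
maybe_hits_blas(PermutedDimsArray(rand(4, 4), (2, 1)), rand(4, 4)) # false
```

If the heuristic and calls_gemm disagree for your A and B, trust calls_gemm; the types that make the heuristic return false (wrappers that are not StridedVecOrMat, non-BLAS element types) are still a reasonable first place to look.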

Thanks!