Help me get rid of extra allocations in this generic batching iterator

I’m trying to make the routines in BatchedRoutines more generic, so I could wrap other BLAS routines to use pointer offsets easier. However, when I wrap the iteration to a function, the batched_gemm routine become type unstable, and have a lot allocations

The original implementation

using LinearAlgebra
import LinearAlgebra: BLAS

for (gemm, elty) in
        ((:dgemm_,:Float64),
         (:sgemm_,:Float32),
         (:zgemm_,:ComplexF64),
         (:cgemm_,:ComplexF32))
    @eval begin
        function batched_gemm!(transA::AbstractChar, transB::AbstractChar, alpha::($elty), A::AbstractArray{$elty, 3}, B::AbstractArray{$elty, 3}, beta::($elty), C::AbstractArray{$elty, 3})
            @assert !BLAS.has_offset_axes(A, B, C)
            @assert size(A, 3) == size(B, 3) == size(C, 3) "batch size mismatch"
            m = size(A, transA == 'N' ? 1 : 2)
            ka = size(A, transA == 'N' ? 2 : 1)
            kb = size(B, transB == 'N' ? 1 : 2)
            n = size(B, transB == 'N' ? 2 : 1)
            if ka != kb || m != size(C,1) || n != size(C,2)
                throw(DimensionMismatch("A has size ($m,$ka), B has size ($kb,$n), C has size $(size(C))"))
            end
            BLAS.chkstride1(A)
            BLAS.chkstride1(B)
            BLAS.chkstride1(C)

            ptrA = Base.unsafe_convert(Ptr{$elty}, A)
            ptrB = Base.unsafe_convert(Ptr{$elty}, B)
            ptrC = Base.unsafe_convert(Ptr{$elty}, C)

            for k in 1:size(A, 3)
                ccall((BLAS.@blasfunc($gemm), BLAS.libblas), Cvoid,
                    (Ref{UInt8}, Ref{UInt8}, Ref{BLAS.BlasInt}, Ref{BLAS.BlasInt},
                     Ref{BLAS.BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BLAS.BlasInt},
                     Ptr{$elty}, Ref{BLAS.BlasInt}, Ref{$elty}, Ptr{$elty},
                     Ref{BLAS.BlasInt}),
                     transA, transB, m, n,
                     ka, alpha, ptrA, max(1,stride(A,2)),
                     ptrB, max(1,stride(B,2)), beta, ptrC,
                     max(1,stride(C,2)))

                ptrA += size(A, 1) * size(A, 2) * sizeof($elty)
                ptrB += size(B, 1) * size(B, 2) * sizeof($elty)
                ptrC += size(C, 1) * size(C, 2) * sizeof($elty)
            end

            C
        end
        function batched_gemm(transA::AbstractChar, transB::AbstractChar, alpha::($elty), A::AbstractArray{$elty, 3}, B::AbstractArray{$elty, 3})
            batched_gemm!(transA, transB, alpha, A, B, zero($elty), similar(B, $elty, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), size(B, 3))))
        end
        function batched_gemm(transA::AbstractChar, transB::AbstractChar, A::AbstractArray{$elty, 3}, B::AbstractArray{$elty, 3})
            batched_gemm(transA, transB, one($elty), A, B)
        end
    end
end

The more generic implementation

using LinearAlgebra
import LinearAlgebra: BLAS

function iterate_batch(f, ::Val{2}, A, B, C)
    BLAS.chkstride1(A)
    BLAS.chkstride1(B)
    BLAS.chkstride1(C)

    ptrA = Base.unsafe_convert(Ptr{Float64}, A)
    ptrB = Base.unsafe_convert(Ptr{Float64}, B)
    ptrC = Base.unsafe_convert(Ptr{Float64}, C)

    for k in 1:size(A, 3)
        f(ptrA, ptrB, ptrC)
        ptrA += stride(A, 3) * sizeof(Float64)
        ptrB += stride(B, 3) * sizeof(Float64)
        ptrC += stride(C, 3) * sizeof(Float64)
    end
    return C
end

for (gemm, elty) in
    ((:dgemm_,:Float64),
     (:sgemm_,:Float32),
     (:zgemm_,:ComplexF64),
     (:cgemm_,:ComplexF32))
@eval begin
    function batched_gemm!(transA::AbstractChar, transB::AbstractChar, alpha::($elty), A::AbstractArray{$elty, 3}, B::AbstractArray{$elty, 3}, beta::($elty), C::AbstractArray{$elty, 3})
        @assert !BLAS.has_offset_axes(A, B, C)
        @assert size(A, 3) == size(B, 3) == size(C, 3) "batch size mismatch"
        m = size(A, transA == 'N' ? 1 : 2)
        ka = size(A, transA == 'N' ? 2 : 1)
        kb = size(B, transB == 'N' ? 1 : 2)
        n = size(B, transB == 'N' ? 2 : 1)
        if ka != kb || m != size(C,1) || n != size(C,2)
            throw(DimensionMismatch("A has size ($m,$ka), B has size ($kb,$n), C has size $(size(C))"))
        end
 
        iterate_batch(Val(2), A, B, C) do ptrA, ptrB, ptrC
            ccall((BLAS.@blasfunc($gemm), BLAS.libblas), Cvoid,
                (Ref{UInt8}, Ref{UInt8}, Ref{BLAS.BlasInt}, Ref{BLAS.BlasInt},
                 Ref{BLAS.BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BLAS.BlasInt},
                 Ptr{$elty}, Ref{BLAS.BlasInt}, Ref{$elty}, Ptr{$elty},
                 Ref{BLAS.BlasInt}),
                 transA, transB, m, n,
                 ka, alpha, ptrA, max(1,stride(A,2)),
                 ptrB, max(1,stride(B,2)), beta, ptrC,
                 max(1,stride(C,2)))
        end
        return C
    end
    function batched_gemm(transA::AbstractChar, transB::AbstractChar, alpha::($elty), A::AbstractArray{$elty, 3}, B::AbstractArray{$elty, 3})
        batched_gemm!(transA, transB, alpha, A, B, zero($elty), similar(B, $elty, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), size(B, 3))))
    end
    function batched_gemm(transA::AbstractChar, transB::AbstractChar, A::AbstractArray{$elty, 3}, B::AbstractArray{$elty, 3})
        batched_gemm(transA, transB, one($elty), A, B)
    end
end
end

And test the @code_warntype

@code_warntype batched_gemm!('N', 'N', 1.0, rand(2, 2, 100), rand(2, 2, 100), 0.0, rand(2, 2, 100))

Is this something related to this ccall? since I notice the cconvert in lowered code has some type instabilities.

It looks like a closure issue. The type instability can be fixed by moving these into iterate_batch’s do-block. Not sure whether this is #15276 or not.

BTW, those pointers in iterate_batch are hard-coded to Float64. I suppose they should be of type $elty just like those in batched_gemm!?