Hi!
I’m trying to wrap the ?gemm_batch functions from the Intel MKL library in Julia. These functions perform matrix-matrix operations on groups of matrices, processing several groups in a single call.
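The operation can also be written out in plain Julia, which is useful for checking a wrapper's output. The sketch below models a single group with `LinearAlgebra.mul!` (not MKL). `gemm_batch_reference!` is a hypothetical helper name. Each `C_i` is overwritten with `α*op(A_i)*op(B_i) + β*C_i`, and every matrix in the group shares `α`, `β` and the transpose flags.

```julia
using LinearAlgebra

# Hypothetical helper: plain-Julia model of ONE group of ?gemm_batch.
# All matrices in the group share α, β and the transpose flags.
function gemm_batch_reference!(C_array, A_array, B_array, α, β,
                               transa::Char = 'N', transb::Char = 'N')
    # 'N' = as is, 'T' = transpose, 'C' = conjugate transpose
    op(X, t) = t == 'N' ? X : t == 'T' ? transpose(X) : adjoint(X)
    # eachindex with several arrays also checks that their lengths agree
    for i in eachindex(A_array, B_array, C_array)
        # 5-argument mul! updates C in place: C = A*B*α + C*β
        mul!(C_array[i], op(A_array[i], transa), op(B_array[i], transb), α, β)
    end
    return C_array
end

# Several groups would be one call per group, each with its own sizes,
# α, β and flags.
```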
For testing purposes, I wrote the following implementation, which assumes there’s only one group of matrices:
```julia
using LinearAlgebra: BlasInt
using BenchmarkTools
using MKL_jll: libmkl_rt

for (C, func) in ((:ComplexF64, :zgemm_batch),
                  (:Float64, :dgemm_batch))
    @eval begin
        function gemm_batch!(A_array::Vector{Matrix{$C}}, B_array::Vector{Matrix{$C}}, C_array::Vector{Matrix{$C}}, α_array::Vector{$C},
                             β_array::Vector{$C}, transa_A::Vector{Char}, transa_B::Vector{Char})
            # we assume that all matrices will be of equal size
            rows_A, cols_A = size(A_array[1])
            rows_B, cols_B = size(B_array[1])
            rows_C, cols_C = size(C_array[1])
            # arrays needed for ?gemm_batch
            m_array = [rows_A]
            n_array = [cols_B]
            k_array = [cols_A]
            transa_A[1] == 'N' ? lda_array = [max(1,rows_A)] : lda_array = [max(1,cols_A)]
            transa_B[1] == 'N' ? ldb_array = [max(1,cols_B)] : ldb_array = [max(1,rows_B)]
            ldc_array = [max(1,rows_A)]
            group_size = [length(A_array)]
            ccall(($(string(func)),libmkl_rt), Cvoid,
                  (Ptr{UInt8}, Ptr{UInt8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
                   Ptr{$C}, Ptr{$C}, Ptr{BlasInt}, Ptr{$C}, Ptr{BlasInt},
                   Ptr{$C}, Ptr{$C}, Ptr{BlasInt}, Ref{BlasInt}, Ptr{BlasInt}),
                  transa_A, transa_B, m_array, n_array, k_array,
                  α_array, A_array, lda_array, B_array, ldb_array,
                  β_array, C_array, ldc_array, 1, group_size)
            return nothing
        end
    end
end
```
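The matrix arguments may also deserve a check. Intel's documentation declares `a_array`, `b_array` and `c_array` of ?gemm_batch as arrays of *pointers* to matrix data. A `Vector{Matrix{ComplexF64}}` passed where the signature says `Ptr{ComplexF64}` most likely hands MKL a pointer to the outer vector's slots, which hold references to Julia `Matrix` objects rather than to their elements. The sketch below uses a made-up `A_array` to show the difference and builds the pointer array with `pointer.`.

```julia
A_array = [rand(ComplexF64, 2, 2) for _ in 1:3]

# pointer(A_array) addresses the slots of the outer Vector, which hold
# references to Matrix objects (note its element type)
@assert pointer(A_array) isa Ptr{Matrix{ComplexF64}}

# What a *_batch routine expects instead: one pointer per matrix, to its data
ptrs = pointer.(A_array)
@assert ptrs isa Vector{Ptr{ComplexF64}}

# Each pointer addresses the first element of its matrix; the matrices must
# stay alive (GC.@preserve) while raw pointers to them are in use
GC.@preserve A_array begin
    @assert all(unsafe_load(ptrs[i]) == A_array[i][1, 1] for i in eachindex(A_array))
end
```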
Then, when I call the function like this:
```julia
Nb = 10
Nk = 5
A = [rand(Nk,Nk) + 1im .* rand(Nk,Nk) for x in 1:Nb]
B = (1.0 + 0.0im) .* A
C = (0.0 + 0.0im) .* A
α = [1.0 + 0.0im]
β = [0.0 + 0.0im]
transa = ['N']
gemm_batch!(A, B, C, α, β, transa, transa)
```
I get the following error:

```
Intel MKL ERROR: Parameter 1 was incorrect on entry to ZGEMM_BATCH .
```
I suspect the error is related to how Julia converts a `Vector{Char}` into a `Ptr{UInt8}`, but since the error message doesn’t give much information, I’m not sure that this is really the cause. I’ve also tried other pointer types, such as `Ptr{Cchar}`, and the same error appears.
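This suspicion can be tested by looking at the bytes a `Vector{Char}` actually holds. Julia's `Char` is 32 bits wide and keeps its UTF-8 bytes in the most significant end of the word, so on a little-endian machine (an assumption, checked via `ENDIAN_BOM` below) the first byte in memory for `'N'` is `0x00`, not `0x4e`. A routine that reads one byte per flag, as a C `char` array is read, would see a NUL there, which is consistent with parameter 1 (`transa_array`) being rejected. It would also explain why `Ptr{Cchar}` makes no difference: the declared pointer type does not change the 4-byte data the vector holds. Converting with `UInt8.` produces the one-byte layout.

```julia
# Julia's Char is 4 bytes, unlike a 1-byte C char
@assert sizeof(Char) == 4

# The UTF-8 byte of 'N' (0x4e) sits in the high end of the 32-bit word
@assert reinterpret(UInt32, 'N') == 0x4e000000

# Raw bytes of a Vector{Char}, in memory order
bytes = collect(reinterpret(UInt8, ['N']))
if ENDIAN_BOM == 0x04030201  # little-endian (x86, most ARM)
    # three zero bytes come first, so a 1-byte reader sees '\0'
    @assert bytes == [0x00, 0x00, 0x00, 0x4e]
end

# Converting each Char to UInt8 gives one byte per flag
transa_bytes = UInt8.(['N'])
@assert transa_bytes == [0x4e]
```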
How can I solve this error?
PS: To check whether I was making a mistake when loading the MKL library, I also implemented the ?gemm functions, and those work perfectly fine.
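One possible corrected version of the single-group wrapper is sketched below. It is not an authoritative implementation and rests on these assumptions: the transpose flags are passed as one-byte `UInt8` values, since Julia's `Char` is 4 bytes wide; the matrices are passed as arrays of data pointers (`pointer.`), kept alive with `GC.@preserve`, because Intel's documentation declares `a_array`, `b_array` and `c_array` as pointer arrays; `m`, `n`, `k` are taken from `op(A)` and `op(B)`; and every leading dimension is the number of *stored* rows, as column-major BLAS defines it. That last point differs from the original's `ldb` (based on `cols_B`) and transposed `lda`, which only coincide for square matrices like those in the test. Integers stay `BlasInt` as in the original; whether MKL reads them as 32- or 64-bit depends on its interface layer (LP64 vs ILP64), which is worth verifying before extending this to several groups.

```julia
using LinearAlgebra: BlasInt
using MKL_jll: libmkl_rt

for (T, func) in ((:ComplexF64, :zgemm_batch), (:Float64, :dgemm_batch))
    @eval function gemm_batch!(A_array::Vector{Matrix{$T}}, B_array::Vector{Matrix{$T}},
                               C_array::Vector{Matrix{$T}}, α_array::Vector{$T},
                               β_array::Vector{$T}, transa_A::Vector{Char}, transa_B::Vector{Char})
        # Still one group of equally sized matrices, as in the original
        A1, B1, C1 = A_array[1], B_array[1], C_array[1]
        notrans_A = uppercase(transa_A[1]) == 'N'
        notrans_B = uppercase(transa_B[1]) == 'N'
        # Dimensions of op(A) (m×k) and op(B) (k×n)
        m = notrans_A ? size(A1, 1) : size(A1, 2)
        k = notrans_A ? size(A1, 2) : size(A1, 1)
        n = notrans_B ? size(B1, 2) : size(B1, 1)
        # Column-major: leading dimension = rows of the stored matrix
        lda = BlasInt[max(1, size(A1, 1))]
        ldb = BlasInt[max(1, size(B1, 1))]
        ldc = BlasInt[max(1, size(C1, 1))]
        group_size = BlasInt[length(A_array)]
        # One byte per flag, like a C char array
        ta, tb = UInt8.(transa_A), UInt8.(transa_B)
        # Keep the matrices alive while MKL works on raw pointers to their data
        GC.@preserve A_array B_array C_array begin
            pA, pB, pC = pointer.(A_array), pointer.(B_array), pointer.(C_array)
            ccall(($(string(func)), libmkl_rt), Cvoid,
                  (Ptr{UInt8}, Ptr{UInt8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
                   Ptr{$T}, Ptr{Ptr{$T}}, Ptr{BlasInt}, Ptr{Ptr{$T}}, Ptr{BlasInt},
                   Ptr{$T}, Ptr{Ptr{$T}}, Ptr{BlasInt}, Ref{BlasInt}, Ptr{BlasInt}),
                  ta, tb, BlasInt[m], BlasInt[n], BlasInt[k],
                  α_array, pA, lda, pB, ldb,
                  β_array, pC, ldc, 1, group_size)
        end
        return nothing
    end
end
```

With this definition the test call above keeps the same arguments, and the result can be compared against `A[i] * B[i]`.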