Hi all,
I want to build a function that takes several tensors, contracts them, and returns a single tensor as the result. I started with TensorOperations.jl as a first draft, but I thought I could improve performance by doing the contractions myself with direct BLAS calls. Here are the functions and the benchmarks comparing them with TensorOperations:
using LinearAlgebra, TensorOperations, BenchmarkTools
"""
    BL_prop_right3(L, A, W, B)

Contract `L` with `A`, `W` and `B` into a new 3-index tensor,

    Lnew[r1, r2, r3] = Σ L[l1, l2, l3] * A[l1, s1, r1] * W[l2, s1, s2, r2] * B[l3, s2, r3]

summed over `l1, l2, l3, s1, s2`. This is done as three matrix products, with
`reshape`/`permutedims` in between to group the contracted indices.

Matrix products are written as `transpose(X) * Y`. For BLAS float element types
(`Float32`, `Float64`, `ComplexF32`, `ComplexF64`) this dispatches to BLAS gemm with a
transposed first operand. For any other `Number` element type it falls back to Julia's
generic matrix multiplication instead of throwing a `MethodError`.

# Arguments
- `L`: size `(a, wl, b)`.
- `A`: size `(a, d, ar)`.
- `W`: size `(wl, d, dp, wr)`.
- `B`: size `(b, dp, br)`.

# Returns
A 3-index array of size `(ar, wr, br)`.

# Throws
- `DimensionMismatch` if a contracted dimension of one input does not match its partner.
"""
function BL_prop_right3(L::Array{T, 3}, A::Array{T, 3},
                        W::Array{T, 4}, B::Array{T, 3}) where T<:Number
    a, wl, b = size(L)
    d, ar = size(A, 2), size(A, 3)
    dp, wr = size(W, 3), size(W, 4)
    br = size(B, 3)

    # Validate every contracted pair up front so mismatches fail with a clear message
    # instead of deep inside a reshape or matrix product.
    size(A, 1) == a ||
        throw(DimensionMismatch("size(A, 1) = $(size(A, 1)) must equal size(L, 1) = $a"))
    size(W, 1) == wl ||
        throw(DimensionMismatch("size(W, 1) = $(size(W, 1)) must equal size(L, 2) = $wl"))
    size(W, 2) == d ||
        throw(DimensionMismatch("size(W, 2) = $(size(W, 2)) must equal size(A, 2) = $d"))
    size(B, 1) == b ||
        throw(DimensionMismatch("size(B, 1) = $(size(B, 1)) must equal size(L, 3) = $b"))
    size(B, 2) == dp ||
        throw(DimensionMismatch("size(B, 2) = $(size(B, 2)) must equal size(W, 3) = $dp"))

    # Step 1: contract l1. Result rows are (l2, l3), columns are (s1, r1).
    LA = transpose(reshape(L, a, wl * b)) * reshape(A, a, d * ar)
    # Reorder (l2, l3, s1, r1) -> (l2, s1, l3, r1) so (l2, s1) can be contracted with W.
    LA = reshape(permutedims(reshape(LA, wl, b, d, ar), (1, 3, 2, 4)), wl * d, b * ar)

    # Step 2: contract (l2, s1). Result rows are (l3, r1), columns are (s2, r2).
    LAW = transpose(LA) * reshape(W, wl * d, dp * wr)
    # Reorder (l3, r1, s2, r2) -> (l3, s2, r1, r2) so (l3, s2) can be contracted with B.
    LAW = reshape(permutedims(reshape(LAW, b, ar, dp, wr), (1, 3, 2, 4)), b * dp, ar * wr)

    # Step 3: contract (l3, s2). Result rows are (r1, r2), columns are r3.
    Lnew = transpose(LAW) * reshape(B, b * dp, br)
    return reshape(Lnew, ar, wr, br)
end
"""
    tensorops(L, A, W, B)

Compute the four-tensor contraction with a single TensorOperations `@tensor` expression:

    out[i, j, k] = Σ L[a, b, c] * A[a, p, i] * W[b, p, q, j] * B[c, q, k]

The sum runs over the repeated labels `a, b, c, p, q`.
"""
function tensorops(L::Array{T, 3}, A::Array{T, 3},
                   W::Array{T, 4}, B::Array{T, 3}) where T<:Number
    # The factors stay in their original order, because the order of a multi-factor
    # `@tensor` product may determine the pairwise contraction sequence.
    @tensor out[i, j, k] := L[a, b, c] * A[a, p, i] * W[b, p, q, j] * B[c, q, k]
    return out
end
"""
    tensorops_slices(L, A, W, B)

Compute the same contraction as `tensorops`, split into three explicit pairwise
`@tensor` contractions. `B`, then `W`, then `A` are absorbed into `L` in turn.
"""
function tensorops_slices(L::Array{T, 3}, A::Array{T, 3},
                          W::Array{T, 4}, B::Array{T, 3}) where T<:Number
    @tensor begin
        # Absorb B: sum over the third index of L.
        LB[x, y, q, k] := L[x, y, z] * B[z, q, k]
        # Absorb W: sum over the pair (y, q).
        LBW[x, p, j, k] := LB[x, y, q, k] * W[y, p, q, j]
        # Absorb A: sum over the pair (x, p).
        out[i, j, k] := LBW[x, p, j, k] * A[x, p, i]
    end
    return out
end
"""
    run_benchmarks(m; w = 100, d = 2)

Benchmark the three contraction implementations on random inputs.

Builds `L` of size `(m, w, m)`, `A` of size `(m, d, m)`, `W` of size `(w, d, d, w)` and
`B` of size `(m, d, m)`. It then checks that `BL_prop_right3` and `tensorops_slices`
agree with `tensorops`, and runs `@btime` on each of the three.

`w` and `d` default to the values the benchmark originally hard-coded.

Throws an `ErrorException` if an implementation's result is not approximately equal to
the `tensorops` result.
"""
function run_benchmarks(m; w = 100, d = 2)
    L = rand(m, w, m)
    A = rand(m, d, m)
    W = rand(w, d, d, w)
    B = rand(m, d, m)

    # Every implementation returns a tensor of size (m, w, m). Compare them before timing
    # so the benchmark never reports numbers for implementations that disagree.
    ref = tensorops(L, A, W, B)
    isapprox(BL_prop_right3(L, A, W, B), ref) ||
        error("BL_prop_right3 result differs from tensorops")
    isapprox(tensorops_slices(L, A, W, B), ref) ||
        error("tensorops_slices result differs from tensorops")

    @btime BL_prop_right3($L, $A, $W, $B)
    @btime tensorops_slices($L, $A, $W, $B)
    @btime tensorops($L, $A, $W, $B)
    return nothing
end

run_benchmarks(256)
For m = 256 I get:
BL_prop_right3:    440.183 ms (38 allocations: 450.00 MiB)
tensorops_slices:  528.910 ms (122 allocations: 250.02 MiB)
tensorops:         523.354 ms (220 allocations: 50.03 MiB)
The ratio of the TensorOperations timings to the BLAS timing, as a function of m, looks like this: [figure not included in this copy]
Do you think I could make the BLAS function even faster?