Hi, I tried to consider tensor contraction A[a,b] B[b,c,d,e] = C[a,c,c,e]
, where B[b,c,d,e] = B[b,c,e,d]
. I am not sure if my following julia code is right
using LinearAlgebra
using BenchmarkTools, TensorOperations
n = 60
A = rand(n, n)
B_4d = zeros(n, n, n, n)
C_new = zeros(n, n, n, n)
for i = 1:n
for j = 1:n
for k = 1:n
for l = 1:k
val = rand()
B_4d[i,j,l,k] = B_4d[i,j,k,l] = val
end
end
end
end
@tensor C_new[a,c,d,e] := A[a,b] * B_4d[b,c,d,e]
println(C_new[1,1,3,2], C_new[1,1,2,3] )
@btime @tensor C_new[a,c,d,e] := A[a,b] * B_4d[b,c,d,e]
@btime begin
for d = 1:n
for e = 1:d
C_new[:,:,d,e] = LinearAlgebra.BLAS.gemm('N', 'N', A, B_4d[:,:,d,e])
C_new[:,:,e,d] .= C_new[:,:,d,e]
end
end
end
println(C_new[1,1,3,2], C_new[1,1,2,3] )
C_new = zeros(n, n, n, n)
@btime begin
for d = 1:n
for e = 1:d
Cview = @view C_new[:,:,d,e]
Bview = @view B_4d[:,:,d,e]
@tensor Cview[g,h] := A[g,i] *Bview[i,h]
C_new[:,:,d,e] .= Cview[:,:]
C_new[:,:,e,d] .= C_new[:,:,d,e]
end
end
end
println(C_new[1,1,3,2], C_new[1,1,2,3] )
I got
173.423 ms
115.171 ms
99.679 ms
If I use Fortran, I can get (by input dimension 60, gfortran
o3 with openblas
)
67.2 ms
78.8ms
40.2ms
the utilization of symmetry can provide ~96 % speed up and much faster than julia
)
Fortran code (refereed from stack overflow)
Program test
integer, parameter :: dp = selected_real_kind(15, 307)
Real(dp), Dimension(:, :), Allocatable :: a
Real(dp), Dimension(:, :, :, :), Allocatable :: b0, b
Real(dp), Dimension(:, :, :, :), Allocatable :: c1, c2, c3, c4
Integer :: na, nb, nc, nd, ne, m, m_iter
Integer :: la, lb, lc, ld, le
Integer :: start, finish, rate
real(dp) :: sum_time
Write (*,*) 'na, nb, nc, nd, ne ?'
Read (*,*) na, nb, nc, nd, ne
Allocate( a ( 1:na, 1:nb ) )
Allocate( b0 ( 1:nb, 1:nc, 1:nd, 1:ne ) )
Allocate( b ( 1:nb, 1:nc, 1:nd, 1:ne ) )
Allocate( c1( 1:na, 1:nc, 1:nd, 1:ne ) )
Allocate( c2( 1:na, 1:nc, 1:nd, 1:ne ) )
Allocate( c3( 1:na, 1:nc, 1:nd, 1:ne ) )
Allocate( c4( 1:na, 1:nc, 1:nd, 1:ne ) )
Call Random_number( a )
Call Random_number( b0 )
c1 = 0.0_dp
c2 = 0.0_dp
c3 = 0.0_dp
c4 = 0.0_dp
m_iter = 5
write (*,*) 'm_iter average', m_iter
b = b0
do ld = 1, nd
do le = 1, ne
b(:, :, ld, le) = ( b0(:, :, ld, le) + b0(:, :, le, ld) )*0.5
end do
end do
sum_time = 0.0
do m = 1, m_iter
c1 = 0.0_dp
Call System_clock(start, rate)
do le = 1, ne
do ld = 1, nd
do lc = 1, nc
do lb = 1, nb
do la = 1, na
c1(la, lc, ld, le) = c1(la, lc, ld, le) + a(la, lb) * b(lb, lc, ld, le)
end do
end do
end do
end do
end do
Call System_clock(finish, rate)
sum_time = sum_time + Real(finish - start, dp) / rate
end do
Write (*,*) 'Time for naive loop edcba', sum_time / m_iter
sum_time = 0.0
do m = 1, m_iter
Call System_clock(start, rate)
Call dgemm( 'N', 'N', na, nc * nd * ne , nb, 1.0_dp, a , Size( a , Dim = 1 ), &
b , Size( b , Dim = 1 ), &
0.0_dp, c2, Size( c2, Dim = 1 ) )
Call System_clock(finish, rate)
sum_time = sum_time + Real(finish - start, dp) / rate
end do
Write (*,*) 'Time for straight dgemm ', sum_time / m_iter
sum_time = 0.0
do m = 1, m_iter
do le = 1, ne
do ld = 1, nd
Call System_clock(start, rate)
Call dgemm( 'N', 'N', na, nb, nb, 1.0_dp, a , Size( a , Dim = 1 ), &
b(:,:,ld, le) , Size( b , Dim = 1 ), &
0.0_dp, c3(:,:,ld, le), Size( c3, Dim = 1 ) )
Call System_clock(finish, rate)
sum_time = sum_time + Real(finish - start, dp) / rate
end do
end do
end do
Write (*,*) 'Time for loop dgemm-2', sum_time / m_iter
sum_time = 0.0
do m = 1, m_iter
do le = 1, ne
do ld = 1, le
Call System_clock(start, rate)
Call dgemm( 'N', 'N', na, nb, nb, 1.0_dp, a , Size( a , Dim = 1 ), &
b(:,:,ld, le) , Size( b , Dim = 1 ), &
0.0_dp, c4(:,:,ld, le), Size( c4, Dim = 1 ) )
Call System_clock(finish, rate)
sum_time = sum_time + Real(finish - start, dp) / rate
c4(:,:,le, ld) = c4(:,:,ld, le)
end do
end do
end do
Write (*,*) 'Time for loop dgemm-symmetry', sum_time / m_iter
do la = 1, na
do lc = 1, nc
do ld = 1, nd
do le = 1, ne
if ( dabs(c2(la,lc,ld,le) - c1(la,lc,ld,le)) > 1.e-6 ) then
write (*,*) '!!! c2', la,lc,ld,le, c2(la,lc,ld,le), c1(la,lc,ld,le)
endif
if ( dabs(c3(la,lc,ld,le) - c1(la,lc,ld,le)) > 1.e-6 ) then
write (*,*) '!!! c3', la,lc,ld,le, c3(la,lc,ld,le), c1(la,lc,ld,le)
endif
if ( dabs(c4(la,lc,ld,le) - c1(la,lc,ld,le)) > 1.e-6 ) then
write (*,*) '!!! c4', la,lc,ld,le, c4(la,lc,ld,le), c1(la,lc,ld,le)
endif
enddo
enddo
enddo
enddo
End