Here is a "type stable" implementation, but if you benchmark it, type stability doesn't really matter for the runtime: the time is spent on the matrix-matrix product. But if type stability is important for some other part of your code, then it may be useful:
"""
    tensordot(A::AbstractArray, B::AbstractArray, ::Val{n})

Contract the last `n` dimensions of `A` with the first `n` dimensions of `B`.

Both arrays are reshaped to matrices (the uncontracted dimensions of `A` become the
rows, the uncontracted dimensions of `B` become the columns), multiplied, and the
product is reshaped to the uncontracted sizes of `A` followed by those of `B`.
Passing `n` as a `Val` makes every tuple length a compile-time constant, so the
dimensionality of the result can be inferred.

# Arguments
- `A`: array with `N1` dimensions.
- `B`: array with `N2` dimensions.
- `Val(n)`: number of dimensions to contract, `0 <= n <= min(N1, N2)`.

# Returns
An array with `N1 + N2 - 2n` dimensions and size
`(size(A)[1:N1-n]..., size(B)[n+1:N2]...)`.

# Throws
- `ArgumentError` if `n` is not an `Int` in `0:min(N1, N2)`.
- `DimensionMismatch` if the last `n` sizes of `A` differ from the first `n` sizes
  of `B`.
"""
function tensordot(
    A::AbstractArray{T1,N1}, B::AbstractArray{T2,N2}, ::Val{n}
) where {T1,T2,N1,N2,n}
    # `n` is a type parameter, so this check is resolved at compile time.
    (n isa Int && 0 <= n <= min(N1, N2)) || throw(
        ArgumentError(
            "number of contracted dimensions must be an Int in 0:$(min(N1, N2)), got $n",
        ),
    )

    outer_a = ntuple(i -> size(A, i), Val(N1 - n))
    inner_a = ntuple(i -> size(A, N1 - n + i), Val(n))
    inner_b = ntuple(i -> size(B, i), Val(n))
    outer_b = ntuple(i -> size(B, i + n), Val(N2 - n))

    # The matrix product below only verifies that the *products* of the contracted
    # sizes agree; compare them pairwise so that e.g. (3, 4) vs (4, 3) is rejected
    # instead of silently producing a wrong contraction.
    inner_a == inner_b || throw(
        DimensionMismatch(
            "last $n dimensions of A have sizes $inner_a, " *
            "but first $n dimensions of B have sizes $inner_b",
        ),
    )

    Amat = reshape(A, prod(outer_a), prod(inner_a))
    Bmat = reshape(B, prod(inner_b), prod(outer_b))
    Cmat = Amat * Bmat
    # Splatting tuples of statically known length keeps the result rank inferrable.
    return reshape(Cmat, (outer_a..., outer_b...))
end
julia> A=rand(10,12,13,14);
julia> B=rand(13,14,15,16,17);
julia> tensordot(A, B, Val(2))==tensordot(A,B,2)
true
julia> @btime tensordot($A,$B,Val(2));
958.343 μs (8 allocations: 3.74 MiB)
julia> @btime tensordot($A,$B,2);
965.924 μs (23 allocations: 3.74 MiB)
julia> @code_warntype tensordot(A,B,Val(2))
Variables
#self#::Core.Compiler.Const(tensordot, false)
A::Array{Float64,4}
B::Array{Float64,5}
#unused#::Core.Compiler.Const(Val{2}(), false)
#273::var"#273#278"{Array{Float64,4}}
#274::var"#274#279"{4,Array{Float64,4}}
#275::var"#275#280"{Array{Float64,5}}
#276::var"#276#281"{2,Array{Float64,5}}
#277::var"#277#282"{4,2,Tuple{Int64,Int64},Tuple{Int64,Int64,Int64}}
s1::Tuple{Int64,Int64}
s2::Tuple{Int64,Int64}
s3::Tuple{Int64,Int64}
s4::Tuple{Int64,Int64,Int64}
s5::NTuple{5,Int64}
Amat::Array{Float64,2}
Bmat::Array{Float64,2}
Cmat::Array{Float64,2}
C::Array{Float64,5}
Body::Array{Float64,5}
1 β %1 = Main.:(var"#273#278")::Core.Compiler.Const(var"#273#278", false)
β %2 = Core.typeof(A)::Core.Compiler.Const(Array{Float64,4}, false)
β %3 = Core.apply_type(%1, %2)::Core.Compiler.Const(var"#273#278"{Array{Float64,4}}, false)
β (#273 = %new(%3, A))
β %5 = #273::var"#273#278"{Array{Float64,4}}
β %6 = ($(Expr(:static_parameter, 3)) - $(Expr(:static_parameter, 5)))::Core.Compiler.Const(2, false)
β %7 = Main.Val(%6)::Core.Compiler.Const(Val{2}(), true)
β (s1 = Main.ntuple(%5, %7))
β %9 = Main.:(var"#274#279")::Core.Compiler.Const(var"#274#279", false)
β %10 = $(Expr(:static_parameter, 3))::Core.Compiler.Const(4, false)
β %11 = Core.typeof(A)::Core.Compiler.Const(Array{Float64,4}, false)
β %12 = Core.apply_type(%9, %10, %11)::Core.Compiler.Const(var"#274#279"{4,Array{Float64,4}}, false)
β (#274 = %new(%12, A))
β %14 = #274::var"#274#279"{4,Array{Float64,4}}
β %15 = Main.Val($(Expr(:static_parameter, 5)))::Core.Compiler.Const(Val{2}(), true)
β (s2 = Main.ntuple(%14, %15))
β %17 = Main.:(var"#275#280")::Core.Compiler.Const(var"#275#280", false)
β %18 = Core.typeof(B)::Core.Compiler.Const(Array{Float64,5}, false)
β %19 = Core.apply_type(%17, %18)::Core.Compiler.Const(var"#275#280"{Array{Float64,5}}, false)
β (#275 = %new(%19, B))
β %21 = #275::var"#275#280"{Array{Float64,5}}
β %22 = Main.Val($(Expr(:static_parameter, 5)))::Core.Compiler.Const(Val{2}(), true)
β (s3 = Main.ntuple(%21, %22))
β %24 = Main.:(var"#276#281")::Core.Compiler.Const(var"#276#281", false)
β %25 = $(Expr(:static_parameter, 5))::Core.Compiler.Const(2, false)
β %26 = Core.typeof(B)::Core.Compiler.Const(Array{Float64,5}, false)
β %27 = Core.apply_type(%24, %25, %26)::Core.Compiler.Const(var"#276#281"{2,Array{Float64,5}}, false)
β (#276 = %new(%27, B))
β %29 = #276::var"#276#281"{2,Array{Float64,5}}
β %30 = ($(Expr(:static_parameter, 4)) - $(Expr(:static_parameter, 5)))::Core.Compiler.Const(3, false)
β %31 = Main.Val(%30)::Core.Compiler.Const(Val{3}(), true)
β (s4 = Main.ntuple(%29, %31))
β %33 = Main.:(var"#277#282")::Core.Compiler.Const(var"#277#282", false)
β %34 = $(Expr(:static_parameter, 3))::Core.Compiler.Const(4, false)
β %35 = $(Expr(:static_parameter, 5))::Core.Compiler.Const(2, false)
β %36 = Core.typeof(s1)::Core.Compiler.Const(Tuple{Int64,Int64}, false)
β %37 = Core.typeof(s4)::Core.Compiler.Const(Tuple{Int64,Int64,Int64}, false)
β %38 = Core.apply_type(%33, %34, %35, %36, %37)::Core.Compiler.Const(var"#277#282"{4,2,Tuple{Int64,Int64},Tuple{Int64,Int64,Int64}}, false)
β %39 = s1::Tuple{Int64,Int64}
β (#277 = %new(%38, %39, s4))
β %41 = #277::var"#277#282"{4,2,Tuple{Int64,Int64},Tuple{Int64,Int64,Int64}}
β %42 = ($(Expr(:static_parameter, 3)) + $(Expr(:static_parameter, 4)))::Core.Compiler.Const(9, false)
β %43 = (2 * $(Expr(:static_parameter, 5)))::Core.Compiler.Const(4, false)
β %44 = (%42 - %43)::Core.Compiler.Const(5, false)
β (s5 = Main.ntuple(%41, %44))
β %46 = Main.prod(s1)::Int64
β %47 = Main.prod(s2)::Int64
β (Amat = Main.reshape(A, %46, %47))
β %49 = Main.prod(s3)::Int64
β %50 = Main.prod(s4)::Int64
β (Bmat = Main.reshape(B, %49, %50))
β (Cmat = Amat * Bmat)
β (C = Main.reshape(Cmat, s5))
βββ return C
The trick is to make the sizes (`N1`, `N2`, `n`) available at compile time, using `Val` and `ntuple`…
Cheers!