Fast type-stable tensor products

Let A and B be two multi-dimensional arrays. I want to form a product of them that contracts the last n dimensions of A with the first n dimensions of B. Calling this product C, one way to do it is with the following code:


function tensordot(A::AbstractArray, B::AbstractArray, n::Int)
    Amat = reshape(A, prod(size(A)[1:end-n]), prod(size(A)[end-n+1:end]))
    Bmat = reshape(B, prod(size(B)[1:n]), prod(size(B)[n+1:end]))
    Cmat = Amat * Bmat
    C = reshape(Cmat, size(A)[1:end-n]..., size(B)[n+1:end]...)
    return C
end

# example
A = randn(10,2,5,7);
B = randn(5,7,3,4,7); 
tensordot(A, B, 2)
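
For checking any faster variant, the contraction described above can also be written as an explicit loop. This is only a reference sketch: the name `tensordot_naive` is made up for illustration and does not come from any package, and the loop is not meant to be fast.

```julia
# Reference implementation (hypothetical helper, for correctness checks only):
# C[i..., k...] = sum over j... of A[i..., j...] * B[j..., k...]
function tensordot_naive(A::AbstractArray, B::AbstractArray, n::Int)
    outerA = size(A)[1:end-n]        # uncontracted dimensions of A
    inner  = size(B)[1:n]            # contracted dimensions
    outerB = size(B)[n+1:end]        # uncontracted dimensions of B
    size(A)[end-n+1:end] == inner || throw(DimensionMismatch("contracted sizes differ"))
    C = zeros(promote_type(eltype(A), eltype(B)), outerA..., outerB...)
    for k in CartesianIndices(outerB), i in CartesianIndices(outerA)
        acc = zero(eltype(C))
        for j in CartesianIndices(inner)
            acc += A[i, j] * B[j, k]   # CartesianIndex arguments splice into the index list
        end
        C[i, k] = acc
    end
    return C
end
```

Comparing `tensordot(A, B, 2) ≈ tensordot_naive(A, B, 2)` on small arrays is a cheap sanity check for the reshape-based version.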

Is this a fast way to compute this kind of product in general? Or is there a better way?

Not really, it's a single call to *.

If the arrays are small, then you may see some advantage to calculating the sizes for reshape using ntupleisms, and not making n global.

Not sure what you mean here.

So this is as fast as possible?

I mean you can squeeze out a μs or so by being careful when calculating these sizes. Whether this matters at all will depend on how big the arrays are:

julia> @btime prod(size($A)[end-2:end])
  335.348 ns (2 allocations: 144 bytes)
70

julia> @btime prod(ntuple(d -> size($A,ndims($A)-d+1), 3))
  1.421 ns (0 allocations: 0 bytes)
70

Strange. Raised an issue here: https://github.com/JuliaLang/julia/issues/34884

But back to the original topic, I think the running time will be dominated by the matrix multiply.
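
The ntuple-based size calculation from the benchmark above could be wrapped in small helpers, roughly as follows. The names `leading_prod` and `trailing_prod` are hypothetical, and passing the count as `Val` is one way (an assumption here, not the only option) to keep the tuple length known to the compiler:

```julia
# Hypothetical helpers: product of the first / last n dimension sizes.
# n is passed as Val{n} so that the tuple length is a compile-time constant.
leading_prod(A::AbstractArray, ::Val{n}) where {n} =
    prod(ntuple(d -> size(A, d), Val(n)))

trailing_prod(A::AbstractArray, ::Val{n}) where {n} =
    prod(ntuple(d -> size(A, ndims(A) - d + 1), Val(n)))

A = randn(10, 2, 5, 7)
trailing_prod(A, Val(3))   # same value as prod(size(A)[end-2:end])
leading_prod(A, Val(1))    # same value as size(A, 1)
```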

Unfortunately this is not type-stable, because the dimensions of the output array depend on n.
Now, in all my uses of this function, n can actually be computed from type information.
So any suggestions on how I can make a type-stable version of this?

Here is a "type stable" implementation, but if you benchmark it, it doesn't really matter for the runtime. The time is spent on the matrix-matrix product. But if type stability is important for some other parts, then it may be useful:

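function tensordot(A::AbstractArray{T1,N1}, B::AbstractArray{T2,N2}, ::Val{n}) where {T1,T2,N1,N2,n}
    s1 = ntuple(i -> size(A, i), Val(N1 - n))        # uncontracted sizes of A
    s2 = ntuple(i -> size(A, N1 - i + 1), Val(n))    # contracted sizes of A (reversed; only the product is used)
    s3 = ntuple(i -> size(B, i), Val(n))             # contracted sizes of B
    s4 = ntuple(i -> size(B, i + n), Val(N2 - n))    # uncontracted sizes of B
    # size of the result: uncontracted sizes of A followed by those of B
    s5 = ntuple(i -> (i <= N1 - n ? s1[i] : s4[i - (N1 - n)]), Val(N1 + N2 - 2n))
    Amat = reshape(A, prod(s1), prod(s2))
    Bmat = reshape(B, prod(s3), prod(s4))
    Cmat = Amat * Bmat
    C = reshape(Cmat, s5)
    return C
end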

julia> A=rand(10,12,13,14);

julia> B=rand(13,14,15,16,17);

julia> tensordot(A, B, Val(2))==tensordot(A,B,2)
true

julia> @btime tensordot($A,$B,Val(2));
  958.343 μs (8 allocations: 3.74 MiB)

julia> @btime tensordot($A,$B,2);
  965.924 μs (23 allocations: 3.74 MiB)

julia> @code_warntype tensordot(A,B,Val(2))
Variables
  #self#::Core.Compiler.Const(tensordot, false)
  A::Array{Float64,4}
  B::Array{Float64,5}
  #unused#::Core.Compiler.Const(Val{2}(), false)
  #273::var"#273#278"{Array{Float64,4}}
  #274::var"#274#279"{4,Array{Float64,4}}
  #275::var"#275#280"{Array{Float64,5}}
  #276::var"#276#281"{2,Array{Float64,5}}
  #277::var"#277#282"{4,2,Tuple{Int64,Int64},Tuple{Int64,Int64,Int64}}
  s1::Tuple{Int64,Int64}
  s2::Tuple{Int64,Int64}
  s3::Tuple{Int64,Int64}
  s4::Tuple{Int64,Int64,Int64}
  s5::NTuple{5,Int64}
  Amat::Array{Float64,2}
  Bmat::Array{Float64,2}
  Cmat::Array{Float64,2}
  C::Array{Float64,5}

Body::Array{Float64,5}
1 ─ %1  = Main.:(var"#273#278")::Core.Compiler.Const(var"#273#278", false)
│   %2  = Core.typeof(A)::Core.Compiler.Const(Array{Float64,4}, false)
│   %3  = Core.apply_type(%1, %2)::Core.Compiler.Const(var"#273#278"{Array{Float64,4}}, false)
│         (#273 = %new(%3, A))
│   %5  = #273::var"#273#278"{Array{Float64,4}}
│   %6  = ($(Expr(:static_parameter, 3)) - $(Expr(:static_parameter, 5)))::Core.Compiler.Const(2, false)
│   %7  = Main.Val(%6)::Core.Compiler.Const(Val{2}(), true)
│         (s1 = Main.ntuple(%5, %7))
│   %9  = Main.:(var"#274#279")::Core.Compiler.Const(var"#274#279", false)
│   %10 = $(Expr(:static_parameter, 3))::Core.Compiler.Const(4, false)
│   %11 = Core.typeof(A)::Core.Compiler.Const(Array{Float64,4}, false)
│   %12 = Core.apply_type(%9, %10, %11)::Core.Compiler.Const(var"#274#279"{4,Array{Float64,4}}, false)
│         (#274 = %new(%12, A))
│   %14 = #274::var"#274#279"{4,Array{Float64,4}}
│   %15 = Main.Val($(Expr(:static_parameter, 5)))::Core.Compiler.Const(Val{2}(), true)
│         (s2 = Main.ntuple(%14, %15))
│   %17 = Main.:(var"#275#280")::Core.Compiler.Const(var"#275#280", false)
│   %18 = Core.typeof(B)::Core.Compiler.Const(Array{Float64,5}, false)
│   %19 = Core.apply_type(%17, %18)::Core.Compiler.Const(var"#275#280"{Array{Float64,5}}, false)
│         (#275 = %new(%19, B))
│   %21 = #275::var"#275#280"{Array{Float64,5}}
│   %22 = Main.Val($(Expr(:static_parameter, 5)))::Core.Compiler.Const(Val{2}(), true)
│         (s3 = Main.ntuple(%21, %22))
│   %24 = Main.:(var"#276#281")::Core.Compiler.Const(var"#276#281", false)
│   %25 = $(Expr(:static_parameter, 5))::Core.Compiler.Const(2, false)
│   %26 = Core.typeof(B)::Core.Compiler.Const(Array{Float64,5}, false)
│   %27 = Core.apply_type(%24, %25, %26)::Core.Compiler.Const(var"#276#281"{2,Array{Float64,5}}, false)
│         (#276 = %new(%27, B))
│   %29 = #276::var"#276#281"{2,Array{Float64,5}}
│   %30 = ($(Expr(:static_parameter, 4)) - $(Expr(:static_parameter, 5)))::Core.Compiler.Const(3, false)
│   %31 = Main.Val(%30)::Core.Compiler.Const(Val{3}(), true)
│         (s4 = Main.ntuple(%29, %31))
│   %33 = Main.:(var"#277#282")::Core.Compiler.Const(var"#277#282", false)
│   %34 = $(Expr(:static_parameter, 3))::Core.Compiler.Const(4, false)
│   %35 = $(Expr(:static_parameter, 5))::Core.Compiler.Const(2, false)
│   %36 = Core.typeof(s1)::Core.Compiler.Const(Tuple{Int64,Int64}, false)
│   %37 = Core.typeof(s4)::Core.Compiler.Const(Tuple{Int64,Int64,Int64}, false)
│   %38 = Core.apply_type(%33, %34, %35, %36, %37)::Core.Compiler.Const(var"#277#282"{4,2,Tuple{Int64,Int64},Tuple{Int64,Int64,Int64}}, false)
│   %39 = s1::Tuple{Int64,Int64}
│         (#277 = %new(%38, %39, s4))
│   %41 = #277::var"#277#282"{4,2,Tuple{Int64,Int64},Tuple{Int64,Int64,Int64}}
│   %42 = ($(Expr(:static_parameter, 3)) + $(Expr(:static_parameter, 4)))::Core.Compiler.Const(9, false)
│   %43 = (2 * $(Expr(:static_parameter, 5)))::Core.Compiler.Const(4, false)
│   %44 = (%42 - %43)::Core.Compiler.Const(5, false)
│         (s5 = Main.ntuple(%41, %44))
│   %46 = Main.prod(s1)::Int64
│   %47 = Main.prod(s2)::Int64
│         (Amat = Main.reshape(A, %46, %47))
│   %49 = Main.prod(s3)::Int64
│   %50 = Main.prod(s4)::Int64
│         (Bmat = Main.reshape(B, %49, %50))
│         (Cmat = Amat * Bmat)
│         (C = Main.reshape(Cmat, s5))
└──       return C

The trick is to make the dimension counts (N1, N2, n) available at compile time, using Val and ntuple.
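
Since the question mentions that n can be computed from type information, one way to use a Val-based method is to derive n from the array type parameters in a thin wrapper. The sketch below is self-contained: `tensordot_val` is a condensed variant of the function above, and `contract_all_but_last` is a hypothetical caller invented for illustration (it contracts every dimension of B except the last one).

```julia
# Condensed Val-based contraction (illustrative name, same idea as above).
function tensordot_val(A::AbstractArray{<:Any,N1}, B::AbstractArray{<:Any,N2}, ::Val{n}) where {N1,N2,n}
    sA = ntuple(i -> size(A, i), Val(N1 - n))        # kept dimensions of A
    sB = ntuple(i -> size(B, i + n), Val(N2 - n))    # kept dimensions of B
    Cmat = reshape(A, prod(sA), :) * reshape(B, :, prod(sB))
    return reshape(Cmat, (sA..., sB...))
end

# Hypothetical caller: n comes from the type parameter N2,
# so no runtime integer has to be turned into a Val.
contract_all_but_last(A::AbstractArray, B::AbstractArray{<:Any,N2}) where {N2} =
    tensordot_val(A, B, Val(N2 - 1))

A = rand(3, 4, 5); B = rand(4, 5, 6)
size(contract_all_but_last(A, B))   # (3, 6)
```

Constructing `Val(n)` from a runtime integer would only move the instability to the call site, so the wrapper takes n from `where` parameters instead.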

Cheers!

@cossio Do you know about https://github.com/Jutho/TensorOperations.jl? Seems like it already implements most of what you're trying to do.

Yes, but (at least from my knowledge of it) it doesn't seem to support tensor products where the number of contracted dimensions is dynamic. Does it?

Also from the things I've tried, TensorOperations.jl is incompatible with Zygote because it modifies temporary arrays in-place.

These two things kept me away from it, but I could be mistaken.

Thanks. It seems that you don't need Val in the body of the function. In fact it can be simplified to:

function tensordot(A::AbstractArray, B::AbstractArray, ::Val{n}) where {n}
    Amat = reshape(A, prod(size(A,i) for i = 1:ndims(A)-n), :)
    Bmat = reshape(B, prod(size(B,i) for i = 1:n), :)
    Cmat = Amat * Bmat
    C = reshape(Cmat, ntuple(i -> i ≤ ndims(A) - n ? size(A,i) : size(B, i - ndims(A) + 2n), ndims(A) + ndims(B) - 2n))
    return C
end

and it will also be type stable!
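
A claim like this can be checked with `@inferred` from the `Test` standard library, which throws when the inferred return type does not match the actual one. The sketch below copies the simplified function under the illustrative name `tensordot_simple` so that it runs on its own; the array sizes are arbitrary.

```julia
using Test  # standard library; provides @inferred

# Copy of the simplified function under a different (illustrative) name.
function tensordot_simple(A::AbstractArray, B::AbstractArray, ::Val{n}) where {n}
    Amat = reshape(A, prod(size(A, i) for i in 1:ndims(A)-n), :)
    Bmat = reshape(B, prod(size(B, i) for i in 1:n), :)
    dims = ntuple(i -> i <= ndims(A) - n ? size(A, i) : size(B, i - ndims(A) + 2n),
                  ndims(A) + ndims(B) - 2n)
    return reshape(Amat * Bmat, dims)
end

A = rand(3, 4, 5, 6); B = rand(5, 6, 2, 3, 4)
C = @inferred tensordot_simple(A, B, Val(2))   # throws if the return type is not inferred
size(C)
```

Here the tuple length is passed to `ntuple` as a plain integer, so inference relies on the compiler propagating that constant. Wrapping the length in `Val(...)`, as in the earlier version, is the more conservative choice if the check fails on some Julia version or for larger dimension counts.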

@raminammour See edit.

Yes, I can never figure out when it is needed and when the compiler infers without it, so I have gotten into the habit of including it just in case :)

One other way you can do this is:

julia> using OMEinsum

julia> tensordot(A, B, 2) ≈ ein" abcd, cdefg -> abefg "(A, B)
true

julia> @macroexpand  ein"abcd,cdefg -> abefg"(A, B)
:((EinCode{(('a', 'b', 'c', 'd'), ('c', 'd', 'e', 'f', 'g')),('a', 'b', 'e', 'f', 'g')}())(A, B))

You can also construct such codes at runtime. And the result should be Zygote-friendly.

However, if you time it, it's not as fast as your function here. I think it's decomposing this into more operations than strictly necessary.

it doesn't seem to support tensor products where the number of contracted dimensions is dynamic

You can do that with TensorOperations using the tensorcontract function.

Here is a link to the function @orialb mentioned: https://jutho.github.io/TensorOperations.jl/stable/functions/#TensorOperations.tensorcontract

But that assumes you know a priori what dimensions to contract. Note that in the function I wrote above, n is an argument that comes from outside and is unknown to the function (even though it can be inferred at compile time, hence the Val).
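
For what it is worth, the index labels that a label-based contraction routine expects can themselves be generated from n. The sketch below only builds the label tuples. `contraction_labels` is a hypothetical helper, and how the labels are passed to `tensorcontract` depends on the TensorOperations.jl version, so that call is deliberately left out.

```julia
# Hypothetical helper: integer labels for contracting the last n dimensions
# of an NA-dimensional array with the first n dimensions of an NB-dimensional one.
function contraction_labels(::Val{NA}, ::Val{NB}, ::Val{n}) where {NA,NB,n}
    IA = ntuple(i -> i, Val(NA))                  # labels 1, ..., NA for A
    IB = ntuple(i -> NA - n + i, Val(NB))         # first n labels coincide with the last n of A
    IC = ntuple(i -> i <= NA - n ? i : i + n,     # uncontracted labels of A, then of B
                Val(NA + NB - 2n))
    return IA, IB, IC
end

IA, IB, IC = contraction_labels(Val(4), Val(5), Val(2))
# IA == (1, 2, 3, 4), IB == (3, 4, 5, 6, 7), IC == (1, 2, 5, 6, 7)
```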

Thanks. But that is not Zygote friendly, correct?

No:

You do not have to use the macro, and can make your function construct the EinCode object by itself.

I don't know but would expect compatibility at least with disable_blas().

There is now also an ncon (and @ncon) function and macro in the latest version of TensorOperations.jl.

But indeed, autodiff support is still on the todo list. I am a bit overwhelmed by the autodiff packages in Julia. Will Zygote.jl become the community standard (or is it already)? And as a third-party package I should only include ZygoteRules.jl and define the appropriate adjoints? And how does it indeed deal with in-place modification?
