# Fast type-stable tensor products

Let `A` and `B` be two multi-dimensional arrays. I want to form a product of them that contracts the last `n` dimensions of `A` with the first `n` dimensions of `B`. Calling this product `C`, one way to do it is with the following code:

``````
function tensordot(A::AbstractArray, B::AbstractArray, n::Int)
    Amat = reshape(A, prod(size(A)[1:end-n]), prod(size(A)[end-n+1:end]))
    Bmat = reshape(B, prod(size(B)[1:n]), prod(size(B)[n+1:end]))
    Cmat = Amat * Bmat
    C = reshape(Cmat, size(A)[1:end-n]..., size(B)[n+1:end]...)
    return C
end

# example
A = randn(10,2,5,7);
B = randn(5,7,3,4,7);
tensordot(A, B, 2)
``````

Is this a fast way to do this kind of product in general? Or is there a better way?

Not really, it's a single call to `*`.

If the arrays are small, then you may see some advantage to calculating the sizes for `reshape` using `ntuple`-isms, and to not making `n` global.

Not sure what you mean here.

So this is as fast as possible?

I mean you can squeeze out a μs or so by being careful when calculating these sizes. Whether this matters at all will depend on how big the arrays are:

``````
julia> @btime prod(size($A)[end-2:end])
  335.348 ns (2 allocations: 144 bytes)
70

julia> @btime prod(ntuple(d -> size($A,ndims($A)-d+1), 3))
  1.421 ns (0 allocations: 0 bytes)
70
``````

Strange. Raised an issue here: https://github.com/JuliaLang/julia/issues/34884

But back to the original topic, I think the running time will be dominated by the matrix multiply.

Unfortunately this is not type-stable, because the dimensions of the output array depend on `n`.
Now, in all my uses of this function, `n` can actually be computed from type information.
So any suggestions on how I can make a type-stable version of this?

Here is a "type stable" implementation, but if you benchmark it, it doesn't really matter for the runtime. The time is spent on the matrix-matrix product. But if type stability is important for some other parts, then it may be useful:

``````
function tensordot(A::AbstractArray{T1,N1}, B::AbstractArray{T2,N2}, ::Val{n}) where {T1,T2,N1,N2,n}
    s1 = ntuple(i -> size(A, i), Val(N1 - n))
    s2 = ntuple(i -> size(A, N1 - i + 1), Val(n))
    s3 = ntuple(i -> size(B, i), Val(n))
    s4 = ntuple(i -> size(B, i + n), Val(N2 - n))
    s5 = ntuple(i -> (i <= N1 - n ? s1[i] : s4[i - (N1 - n)]), Val(N1 + N2 - 2n))
    Amat = reshape(A, prod(s1), prod(s2))
    Bmat = reshape(B, prod(s3), prod(s4))
    Cmat = Amat * Bmat
    C = reshape(Cmat, s5)
    return C
end

julia> A = rand(10,12,13,14);

julia> B = rand(13,14,15,16,17);

julia> tensordot(A, B, Val(2)) == tensordot(A, B, 2)
true

julia> @btime tensordot($A, $B, Val(2));
  958.343 μs (8 allocations: 3.74 MiB)

julia> @btime tensordot($A, $B, 2);
  965.924 μs (23 allocations: 3.74 MiB)

julia> @code_warntype tensordot(A, B, Val(2))
Variables
  #self#::Core.Compiler.Const(tensordot, false)
  A::Array{Float64,4}
  B::Array{Float64,5}
  #unused#::Core.Compiler.Const(Val{2}(), false)
  #273::var"#273#278"{Array{Float64,4}}
  #274::var"#274#279"{4,Array{Float64,4}}
  #275::var"#275#280"{Array{Float64,5}}
  #276::var"#276#281"{2,Array{Float64,5}}
  #277::var"#277#282"{4,2,Tuple{Int64,Int64},Tuple{Int64,Int64,Int64}}
  s1::Tuple{Int64,Int64}
  s2::Tuple{Int64,Int64}
  s3::Tuple{Int64,Int64}
  s4::Tuple{Int64,Int64,Int64}
  s5::NTuple{5,Int64}
  Amat::Array{Float64,2}
  Bmat::Array{Float64,2}
  Cmat::Array{Float64,2}
  C::Array{Float64,5}

Body::Array{Float64,5}
1 ─ %1  = Main.:(var"#273#278")::Core.Compiler.Const(var"#273#278", false)
│   %2  = Core.typeof(A)::Core.Compiler.Const(Array{Float64,4}, false)
│   %3  = Core.apply_type(%1, %2)::Core.Compiler.Const(var"#273#278"{Array{Float64,4}}, false)
│         (#273 = %new(%3, A))
│   %5  = #273::var"#273#278"{Array{Float64,4}}
│   %6  = ($(Expr(:static_parameter, 3)) - $(Expr(:static_parameter, 5)))::Core.Compiler.Const(2, false)
│   %7  = Main.Val(%6)::Core.Compiler.Const(Val{2}(), true)
│         (s1 = Main.ntuple(%5, %7))
│   %9  = Main.:(var"#274#279")::Core.Compiler.Const(var"#274#279", false)
│   %10 = $(Expr(:static_parameter, 3))::Core.Compiler.Const(4, false)
│   %11 = Core.typeof(A)::Core.Compiler.Const(Array{Float64,4}, false)
│   %12 = Core.apply_type(%9, %10, %11)::Core.Compiler.Const(var"#274#279"{4,Array{Float64,4}}, false)
│         (#274 = %new(%12, A))
│   %14 = #274::var"#274#279"{4,Array{Float64,4}}
│   %15 = Main.Val($(Expr(:static_parameter, 5)))::Core.Compiler.Const(Val{2}(), true)
│         (s2 = Main.ntuple(%14, %15))
│   %17 = Main.:(var"#275#280")::Core.Compiler.Const(var"#275#280", false)
│   %18 = Core.typeof(B)::Core.Compiler.Const(Array{Float64,5}, false)
│   %19 = Core.apply_type(%17, %18)::Core.Compiler.Const(var"#275#280"{Array{Float64,5}}, false)
│         (#275 = %new(%19, B))
│   %21 = #275::var"#275#280"{Array{Float64,5}}
│   %22 = Main.Val($(Expr(:static_parameter, 5)))::Core.Compiler.Const(Val{2}(), true)
│         (s3 = Main.ntuple(%21, %22))
│   %24 = Main.:(var"#276#281")::Core.Compiler.Const(var"#276#281", false)
│   %25 = $(Expr(:static_parameter, 5))::Core.Compiler.Const(2, false)
│   %26 = Core.typeof(B)::Core.Compiler.Const(Array{Float64,5}, false)
│   %27 = Core.apply_type(%24, %25, %26)::Core.Compiler.Const(var"#276#281"{2,Array{Float64,5}}, false)
│         (#276 = %new(%27, B))
│   %29 = #276::var"#276#281"{2,Array{Float64,5}}
│   %30 = ($(Expr(:static_parameter, 4)) - $(Expr(:static_parameter, 5)))::Core.Compiler.Const(3, false)
│   %31 = Main.Val(%30)::Core.Compiler.Const(Val{3}(), true)
│         (s4 = Main.ntuple(%29, %31))
│   %33 = Main.:(var"#277#282")::Core.Compiler.Const(var"#277#282", false)
│   %34 = $(Expr(:static_parameter, 3))::Core.Compiler.Const(4, false)
│   %35 = $(Expr(:static_parameter, 5))::Core.Compiler.Const(2, false)
│   %36 = Core.typeof(s1)::Core.Compiler.Const(Tuple{Int64,Int64}, false)
│   %37 = Core.typeof(s4)::Core.Compiler.Const(Tuple{Int64,Int64,Int64}, false)
│   %38 = Core.apply_type(%33, %34, %35, %36, %37)::Core.Compiler.Const(var"#277#282"{4,2,Tuple{Int64,Int64},Tuple{Int64,Int64,Int64}}, false)
│   %39 = s1::Tuple{Int64,Int64}
│         (#277 = %new(%38, %39, s4))
│   %41 = #277::var"#277#282"{4,2,Tuple{Int64,Int64},Tuple{Int64,Int64,Int64}}
│   %42 = ($(Expr(:static_parameter, 3)) + $(Expr(:static_parameter, 4)))::Core.Compiler.Const(9, false)
│   %43 = (2 * $(Expr(:static_parameter, 5)))::Core.Compiler.Const(4, false)
│   %44 = (%42 - %43)::Core.Compiler.Const(5, false)
│         (s5 = Main.ntuple(%41, %44))
│   %46 = Main.prod(s1)::Int64
│   %47 = Main.prod(s2)::Int64
│         (Amat = Main.reshape(A, %46, %47))
│   %49 = Main.prod(s3)::Int64
│   %50 = Main.prod(s4)::Int64
│         (Bmat = Main.reshape(B, %49, %50))
│         (Cmat = Amat * Bmat)
│         (C = Main.reshape(Cmat, s5))
└──       return C
``````

The trick is to make the sizes (`N1`, `N2`, `n`) available at compile time, via `Val` and `ntuple`…
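A minimal illustration of why that works (a sketch; `sizes_dyn` and `sizes_val` are made-up names): when the tuple length is a plain `Int`, it is a runtime value and the result's tuple type cannot be inferred, but when it is encoded in a `Val{n}` type parameter, `ntuple` produces a fully inferred `NTuple{n,Int}`.

``````
# Length passed as a runtime Int: the number of elements in the returned
# tuple is unknown to the compiler.
sizes_dyn(A, n) = ntuple(i -> size(A, i), n)

# Length carried in the type via Val{n}: the compiler knows the tuple length.
sizes_val(A, ::Val{n}) where {n} = ntuple(i -> size(A, i), Val(n))

A = randn(10, 2, 5, 7)
sizes_val(A, Val(2))  # (10, 2)
``````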

Cheers!


@cossio Do you know about https://github.com/Jutho/TensorOperations.jl? Seems like it already implements most of what you're trying to do.

Yes, but (at least from my knowledge of it) it doesn't seem to support tensor products where the number of contracted dimensions is dynamic. Does it?

Also, from the things I've tried, `TensorOperations.jl` is incompatible with Zygote because it modifies temporary arrays in-place.

These two things kept me away from it, but I could be mistaken.

Thanks. It seems that you don't need `Val` in the body of the function. In fact it can be simplified to:

``````
function tensordot(A::AbstractArray, B::AbstractArray, ::Val{n}) where {n}
    Amat = reshape(A, prod(size(A,i) for i = 1:ndims(A)-n), :)
    Bmat = reshape(B, prod(size(B,i) for i = 1:n), :)
    Cmat = Amat * Bmat
    C = reshape(Cmat, ntuple(i -> i ≤ ndims(A) - n ? size(A,i) : size(B, i - ndims(A) + 2n), ndims(A) + ndims(B) - 2n))
    return C
end
``````

and it will also be type stable!
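One quick way to check that claim is `Test.@inferred`, which throws if the actual return type differs from the inferred one. A sketch (the definition is repeated from above so the snippet is self-contained; the array sizes are arbitrary):

``````
using Test: @inferred

# definition from above, repeated so this snippet runs on its own
function tensordot(A::AbstractArray, B::AbstractArray, ::Val{n}) where {n}
    Amat = reshape(A, prod(size(A,i) for i = 1:ndims(A)-n), :)
    Bmat = reshape(B, prod(size(B,i) for i = 1:n), :)
    return reshape(Amat * Bmat,
                   ntuple(i -> i ≤ ndims(A) - n ? size(A,i) : size(B, i - ndims(A) + 2n),
                          ndims(A) + ndims(B) - 2n))
end

A = randn(10, 2, 5, 7);
B = randn(5, 7, 3, 4);
C = @inferred tensordot(A, B, Val(2))  # @inferred throws if inference fails
size(C)  # (10, 2, 3, 4)
``````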

@raminammour See edit.

Yes, I can never figure out when it is needed and when the compiler infers without it, so I have gotten into the habit of including it just in case.

One other way you can do this is:

``````
julia> using OMEinsum

julia> tensordot(A, B, 2) ≈ ein"abcd,cdefg -> abefg"(A, B)
true

julia> @macroexpand ein"abcd,cdefg -> abefg"(A, B)
:((EinCode{(('a', 'b', 'c', 'd'), ('c', 'd', 'e', 'f', 'g')),('a', 'b', 'e', 'f', 'g')}())(A, B))
``````

You can write such codes at runtime. And the result should be Zygote-friendly.

However, if you time it, it's not as fast as your function here. I think it's decomposing this into more operations than strictly necessary.

> it doesn't seem to support tensor products where the number of contracted dimensions is dynamic

You can do that with `TensorOperations` using the `tensorcontract` function.
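For illustration, a hypothetical (untested) sketch of how that could look, assuming `tensorcontract`'s label-based method `tensorcontract(A, IA, B, IB)`, which contracts the indices whose labels appear in both `IA` and `IB`; the label lists can be built at runtime from `n` (`tensordot_tc` is a made-up name):

``````
using TensorOperations

function tensordot_tc(A, B, n)
    IA = collect(1:ndims(A))
    # give B's first n indices the same labels as A's last n indices,
    # so those are the ones that get contracted
    IB = collect(ndims(A) - n + 1 : ndims(A) - n + ndims(B))
    return tensorcontract(A, IA, B, IB)
end
``````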


Here's a link to the function @orialb mentioned: https://jutho.github.io/TensorOperations.jl/stable/functions/#TensorOperations.tensorcontract


But that assumes you know a priori what dimensions to contract. Note that in the function I wrote above, `n` is an argument that comes from outside and is unknown to the function (even though it can be inferred at compile time, hence the `Val`).

Thanks. But that is not Zygote friendly, correct?

No: you do not have to use the macro, and can make your function construct the `EinCode` object by itself.
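For example, something like the following hypothetical sketch (untested; `tensordot_ein` is a made-up name), which mirrors the `@macroexpand` output above by building the index-label tuples at runtime from `n`. Note that constructing the type parameters at runtime makes the call type-unstable, which is the trade-off for a dynamic `n`:

``````
using OMEinsum

function tensordot_ein(A, B, n)
    ia = ntuple(i -> 'a' + i - 1, ndims(A))                 # e.g. ('a','b','c','d')
    ib = ntuple(i -> 'a' + ndims(A) - n + i - 1, ndims(B))  # first n labels shared with ia
    iy = (ia[1:end-n]..., ib[n+1:end]...)                   # open indices of A, then of B
    return EinCode{(ia, ib), iy}()(A, B)
end
``````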

I don't know, but I would expect compatibility at least with `disable_blas()`.

There is now also an `ncon` (and `@ncon`) function and macro in the latest version of TensorOperations.jl.

But indeed, autodiff support is still on the todo list. I am a bit overwhelmed by the autodiff packages in Julia. Will Zygote.jl become the community standard (or is it already)? And as a third-party package, should I only include ZygoteRules.jl and define the appropriate adjoints? And how does it deal with in-place modification?
