Matrix chain multiplication with arbitrary number of matrices

Hi everyone,

I want to multiply (or contract) an arbitrary number of matrices.

As an example, the naive version with 4 matrices could be written as:

using LinearAlgebra  # provides tr

m1 = rand(10,10)
m2 = rand(10,10)
m3 = rand(10,10)
m4 = rand(10,10)

result = tr(*(m1,m2,m3,m4))

But what if the number of matrices must be established at run-time?

Furthermore, if I’m only interested in the trace of the resulting matrix, as in the code snippet above, is there a way to compute just the trace without forming the entire matrix first?

I’m probably looking for something very similar to NumPy’s multi_dot function (numpy.linalg.multi_dot) in Python.

Thank you so much in advance to everyone.

M = reduce(*, [m1, m2, m3, m4])
tr(M)

Do all of these matrices have the same size? Do you care about efficiency?

You can compute the trace with one fewer matrix multiplication.

julia> using LinearAlgebra  # for tr and dot

julia> M = [rand(10,10) for i = 1:4];

julia> tr(prod(M))
630.566913491192

julia> @views dot(M[1]', prod(M[2:end])) # one fewer matrix multiply
630.5669134911918
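The trick above works because tr(A*B) is the sum over i, j of A[i,j]*B[j,i], which is exactly what dot(A', B) computes, so the last matrix product never has to be formed. A minimal sketch that wraps this for a vector whose length is only known at run time follows; the name trace_of_product is hypothetical (not from any package), and it assumes the sizes are compatible and the overall product is square.

```julia
using LinearAlgebra

# Hypothetical helper (not from any package): trace of Ms[1] * Ms[2] * ... * Ms[end].
# Assumes the sizes are compatible and the full product is square.
function trace_of_product(Ms::AbstractVector{<:AbstractMatrix})
    isempty(Ms) && throw(ArgumentError("need at least one matrix"))
    length(Ms) == 1 && return tr(Ms[1])
    # Multiply all but the first matrix, left to right.
    B = @views reduce(*, Ms[2:end])
    # tr(A * B) == sum of A[i, j] * B[j, i] == dot(A', B): no final matrix product.
    return dot(Ms[1]', B)
end

Ms = [rand(10, 10) for _ in 1:rand(2:6)]   # number of matrices chosen at run time
trace_of_product(Ms) ≈ tr(prod(Ms))        # agrees up to floating-point rounding
```

This saves one of the matrix multiplications in the chain; the remaining products are still computed in full.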

If the matrices are large and you only need to estimate the trace, then there are potentially even faster algorithms, such as Hutchinson’s trace estimator, that avoid multiplying matrices entirely by using only matrix–vector products.
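As a rough illustration of the matrix–vector idea, here is a minimal sketch of a Hutchinson-style estimator. It averages z' * (M1 * ... * Mk) * z over random ±1 probe vectors z, applying the chain to each vector from right to left. The function name hutchinson_trace and the nsamples keyword are assumptions made for this example, not an existing API, and the result is a stochastic estimate whose error shrinks roughly like 1/sqrt(nsamples).

```julia
using LinearAlgebra, Random

# Hypothetical helper: Hutchinson-style estimate of tr(Ms[1] * ... * Ms[end])
# using only matrix-vector products. `nsamples` is an assumed tuning parameter.
function hutchinson_trace(Ms::AbstractVector{<:AbstractMatrix};
                          nsamples::Integer = 100,
                          rng::AbstractRNG = Random.default_rng())
    n = size(Ms[end], 2)                  # length of the probe vectors
    est = 0.0
    for _ in 1:nsamples
        z = rand(rng, (-1.0, 1.0), n)     # Rademacher probe: entries are +1 or -1
        v = z
        for M in Iterators.reverse(Ms)
            v = M * v                     # apply the chain right to left
        end
        est += dot(z, v)                  # z' * (M1 * ... * Mk) * z
    end
    return est / nsamples
end

Ms = [rand(200, 200) for _ in 1:4]
hutchinson_trace(Ms; nsamples = 500)      # compare against tr(prod(Ms))
```

Each sample costs one matrix–vector product per matrix in the chain, so this only pays off when the matrices are large and an approximate answer is acceptable.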

If all the matrices have the same size, you can just use prod as noted above.

If the matrices have different sizes, in principle there is a matrix-chain ordering that minimizes the cost. multi_dot chooses the evaluation order automatically to reduce this cost; an analogue in Julia is the MatrixChainMultiply.jl package, but it needs updating (it dates back to Julia 0.5!).
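For reference, the ordering problem can be solved with the classic textbook dynamic program. The sketch below is only an illustration of that idea under simple assumptions (dense matrices, cost measured as the number of scalar multiplications); chain_order and chain_multiply are hypothetical names and this is not the code used by multi_dot or MatrixChainMultiply.jl.

```julia
# Hypothetical helper: dynamic program for the matrix-chain ordering problem.
# Matrix i has size dims[i] x dims[i+1], so n matrices need n + 1 entries.
# Returns the minimal scalar-multiplication count and the table of split points.
function chain_order(dims::AbstractVector{<:Integer})
    n = length(dims) - 1
    cost = zeros(Int, n, n)
    splits = zeros(Int, n, n)
    for len in 2:n, i in 1:(n - len + 1)
        j = i + len - 1
        cost[i, j] = typemax(Int)
        for k in i:(j - 1)
            # Cost of (Ms[i..k]) * (Ms[k+1..j]) plus the cost of each factor.
            c = cost[i, k] + cost[k + 1, j] + dims[i] * dims[k + 1] * dims[j + 1]
            if c < cost[i, j]
                cost[i, j] = c
                splits[i, j] = k
            end
        end
    end
    return cost[1, n], splits
end

# Multiply Ms[i:j] recursively, following the split table.
function chain_multiply(Ms, splits, i = 1, j = length(Ms))
    i == j && return Ms[i]
    k = splits[i, j]
    return chain_multiply(Ms, splits, i, k) * chain_multiply(Ms, splits, k + 1, j)
end

Ms = [rand(10, 30), rand(30, 5), rand(5, 60)]
dims = [size(Ms[1], 1); size.(Ms, 2)]     # [10, 30, 5, 60]
best, splits = chain_order(dims)
P = chain_multiply(Ms, splits)            # same result as Ms[1] * Ms[2] * Ms[3]
```

The table costs O(n^3) time to build for n matrices, which is negligible next to the multiplications themselves unless the matrices are tiny.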

(In practice there rarely seems to be a need for automatic matrix-chain algorithms: if you are multiplying a bunch of matrices of different sizes, it’s usually only 3–4 matrices, and the optimal ordering is easily identified statically by hand. If you are multiplying an arbitrary number of matrices at runtime, I’m guessing their sizes are all the same?)


Write it as a tensor reduction and look at Tullio.jl (mcabbott/Tullio.jl on GitHub).


The Tullio README at one point said:

"Chained multiplication is also very slow, because it doesn’t know there’s a better algorithm."

Not sure if this is still true?


Yes, you are right.

That’s brilliant. Thank you so much.

P.S. Thanks also to all the others who replied.
