Efficient trace of product of matrices


#1

In quantum statistical mechanics, it is common to compute the trace of the product of two (or more) matrices. The obvious way of expressing this in Julia is tr(A * B); however, this computes all elements of A * B even though only the diagonal elements are required for the result, turning what could be an O(n^2) operation into an O(n^3) one.
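Written out for n×n matrices, only the diagonal entries of the product enter the trace, so each product A_ik B_ki is needed exactly once:

$$
\operatorname{tr}(AB) = \sum_{i=1}^{n} (AB)_{ii} = \sum_{i=1}^{n} \sum_{k=1}^{n} A_{ik} B_{ki}
$$

That is n^2 multiplications, compared with n^3 for forming the full product A * B with ordinary (schoolbook) matrix multiplication.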

Because this is highly suboptimal for large matrices, I often define

import Compat
function trprod(A::AbstractMatrix, B::AbstractMatrix)
    let Ah = A'
        size(Ah) == size(B) || throw(DimensionMismatch())
        Compat.dot(Ah, B)
    end
end

I prefer not to write Compat.dot(A', B) directly because I think this obscures what I mean (and besides, this would skip the check on the shapes of A and B). But I always wonder: is there an existing way of expressing this functionality in a clear and direct way that I’m simply unaware of?

Assuming there is no existing method, this pattern is so common that I think it may make sense to have a widely available package with this functionality. That said, I’m trying to figure out the appropriate generality and level of abstraction for such a package. Some thoughts:

  • It would be nice to have optimized methods for when one or more of the matrices are Hermitian, since this is very common in quantum mechanics.
  • If the matrices are not all square, then due to the cyclic property of the trace, there may be a preferred method for computing trprod which minimizes the overall number of operations, given the shapes of the matrices.
  • More generally, one might wish to consider a method that explicitly computes only the diagonal of a product of matrices. Then one could use e.g. tr(diagprod(A, B)) (although this would result in an unnecessary allocation to store the diagonal). However, this is slightly at odds with the previous point, which relies on a special property of the trace for optimization. That said, would having such a diagprod function be useful in other contexts/domains?
  • One alternative to diagprod mentioned above would be to have a view that computes elements of a product of two matrices on-the-fly as they are needed. Then one could use tr(prodview(A, B)) to compute the trace without any need for allocation (although I would need to think more about whether this would result in efficient code from a cache-locality perspective). Note that computing elements on-the-fly would be inefficient for a product of more than two matrices, so prodview should accept no more than two arguments, and one would need to do e.g. prodview(A, B * C) explicitly (and pay attention to the optimal way to decompose the multiplication based on the relative sizes of the matrices). But might such a function be useful in other contexts that I have not considered here?
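To make the diagprod and prodview ideas concrete, here is a minimal sketch for Julia 1.x using the LinearAlgebra standard library. Both names are hypothetical (no such functions exist in Base or LinearAlgebra), and diagprod is restricted to the trace-compatible case A (m×n), B (n×m). This diagprod returns a plain Vector, so it is reduced with sum; wrapping it in Diagonal would let tr apply instead. The lazy view computes each element (i, j) of A * B as the unconjugated product of row i of A with column j of B.

```julia
using LinearAlgebra

# Hypothetical diagprod: the diagonal of A*B without forming A*B.
# For A (m×n) and B (n×m), entry i is sum_k A[i,k]*B[k,i], so the work is O(m*n).
function diagprod(A::AbstractMatrix, B::AbstractMatrix)
    m, n = size(A)
    size(B) == (n, m) || throw(DimensionMismatch("size(A') must equal size(B)"))
    d = zeros(promote_type(eltype(A), eltype(B)), m)
    for k in 1:n, i in 1:m   # inner loop walks down column k of A
        d[i] += A[i, k] * B[k, i]
    end
    return d
end

# Hypothetical lazy product view: element (i, j) of A*B is computed on demand
# as the (unconjugated) product of row i of A with column j of B.
struct ProdView{T,TA<:AbstractMatrix,TB<:AbstractMatrix} <: AbstractMatrix{T}
    A::TA
    B::TB
end

function prodview(A::AbstractMatrix, B::AbstractMatrix)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("inner dimensions must agree"))
    T = promote_type(eltype(A), eltype(B))
    return ProdView{T,typeof(A),typeof(B)}(A, B)
end

Base.size(P::ProdView) = (size(P.A, 1), size(P.B, 2))

function Base.getindex(P::ProdView{T}, i::Int, j::Int) where {T}
    s = zero(T)
    for k in axes(P.A, 2)
        s += P.A[i, k] * P.B[k, j]
    end
    return s
end
```

With this, tr(prodview(A, B)) should work through LinearAlgebra's generic AbstractMatrix fallback and only evaluates diagonal elements, although that fallback may still allocate a temporary for the diagonal, so whether it is truly allocation-free depends on the Julia version.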

#2

How about?

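using Compat
# `@trace X * Y` is rewritten to `Compat.dot(X', Y)`, which equals tr(X * Y):
# dot conjugates its first argument, and conj(X')[i,j] == X[j,i].
macro trace(expr)
    # Only a two-factor product `X * Y` is supported
    (expr isa Expr && expr.head == :call && expr.args[1] == :* && length(expr.args) == 3) ||
        throw(ArgumentError("@trace expects an expression of the form X * Y"))
    X = esc(expr.args[2])
    Y = esc(expr.args[3])
    return :(Compat.dot(($X)', $Y))
end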

A = rand(3, 3)
@trace A' * A

#3

How about implementing it like this (and maybe making mohamed’s macro call this function)?

function trprod2(A,B)
    @boundscheck size(A) == size(B') || throw(BoundsError())
    out = zero(promote_type(eltype(A),eltype(B)))
    @inbounds for i ∈ 1:size(A,2)
        for j ∈ 0:8:size(A,1)-8
            Base.Cartesian.@nexprs 8 k -> out += A[k+j,i]*B[i,k+j]
        end
        for j ∈ size(A,1)+1-(size(A,1)%8):size(A,1)
            out += A[j,i] * B[i,j] 
        end
    end
    out
end

julia> A = randn(1000,1000); B = randn(1000,1000);

julia> @btime trprod($A, $B)
  1.632 ms (1 allocation: 16 bytes)
139.99033252698055

julia> @btime trprod2($A, $B)
  1.084 ms (0 allocations: 0 bytes)
139.99033252695608

Because the above is mostly just for loops in Julia, it should also be easier to modify for the case of Hermitian matrices.

EDIT: FWIW, simple @inbounds for loops take >1.1 ms on this computer. That’s probably good enough to just stick with

function trprod4(A,B)
    @boundscheck size(A) == size(B') || throw(BoundsError())
    out = zero(promote_type(eltype(A),eltype(B)))
    @inbounds for j ∈ 1:size(A,1), i ∈ 1:size(A,2)
        out += A[j,i]*B[i,j]
    end
    out
end

although, as you note, you could have a wrapper check the sizes of A and B and then dispatch to a loop order in which the longer dimension gets accessed down the columns, while the shorter dimension gets the bad memory accesses.
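A sketch of that wrapper idea, assuming Julia 1.x and ordinary 1-based arrays (the name trprod_oriented is hypothetical). For A (m×n) and B (n×m), either loop order touches the same m*n products; the only difference is which matrix is read contiguously in the inner loop. Whether the branch choice actually pays off would need benchmarking.

```julia
using LinearAlgebra

# Hypothetical wrapper: tr(A*B) for A (m×n) and B (n×m), choosing the loop order
# so that the inner, contiguous (column-major) loop runs over the longer dimension.
# @inbounds assumes ordinary 1-based arrays, which the size check then makes safe.
function trprod_oriented(A::AbstractMatrix, B::AbstractMatrix)
    m, n = size(A)
    size(B) == (n, m) || throw(DimensionMismatch("size(A') must equal size(B)"))
    out = zero(promote_type(eltype(A), eltype(B)))
    if m >= n
        # inner loop walks down column j of A (length m); B is read with stride n
        @inbounds for j in 1:n, i in 1:m
            out += A[i, j] * B[j, i]
        end
    else
        # inner loop walks down column i of B (length n); A is read with stride m
        @inbounds for i in 1:m, j in 1:n
            out += A[i, j] * B[j, i]
        end
    end
    return out
end
```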

EDIT:

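@generated function trprod3(A,B,::Val{U}=Val(4)) where U
    quote
        @boundscheck size(A) == size(B') || throw(BoundsError())
        out = zero(promote_type(eltype(A),eltype(B)))
        # Unrolled main loop: j steps through the rows of A in blocks of U,
        # and @nexprs expands each block into U explicit multiply-adds
        @inbounds for j ∈ 0:$U:size(A,1)-$U
            for i ∈ 1:size(A,2)
                Base.Cartesian.@nexprs $U k -> out += A[k+j,i]*B[i,k+j]
            end
        end
        # Remainder: the last size(A,1) % U rows of A
        @inbounds for j ∈ size(A,1)+1-(size(A,1)%$U):size(A,1), i ∈ 1:size(A,2)
            out += A[j,i] * B[i,j]
        end
        out
    end
end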

This was a little faster (< 1 ms) with U = 4, and much slower with U = 3 or U > 4. I thought this might access memory a little more efficiently, especially with U = 8, but the timings didn’t bear that out (> 2 ms).


#4

I assume A and B are Hermitian? Then Tr(AB) = Tr(A'B), which is simply the natural (Hilbert-Schmidt) inner product on matrices, written in Julia as dot(A,B).
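A quick check of that identity, assuming Julia 1.x where dot on two matrices is the Hilbert-Schmidt (Frobenius) inner product, conjugating its first argument. Hermiticity of the first factor is what lets tr(A * B) be written as dot(A, B); for a general complex A the two differ.

```julia
using LinearAlgebra

# dot on matrices conjugates its first argument: dot(A, B) == tr(A' * B)
A = randn(ComplexF64, 3, 3)
B = randn(ComplexF64, 3, 3)
@assert dot(A, B) ≈ tr(A' * B)

# If H is Hermitian, H' == H, so dot(H, B) also equals tr(H * B)
H = A + A'
@assert dot(H, B) ≈ tr(H * B)
```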


#5

It seems the OP explicitly wants to avoid taking the complex conjugate, so I would assume that no, the matrices are not Hermitian.


#6

In all the uses I know, at least one matrix (the observable, if not the state) is Hermitian. I’d be interested to know the application where both matrices are non-Hermitian.


#7

I am looking for an idiomatic yet efficient way of expressing the trace of the product of two (or more) matrices. I mention above the reasons I consider Compat.dot alone to be suboptimal (though my example implementation relies on it, naturally).


#8

I do not understand your criticism: dot does check sizes, is as short as can be, naturally expresses the inner product structure, and is optimally efficient (if it’s not, it’s worth a PR). It does not extend naturally to more than two objects, though, but the cases I’ve encountered are best expressed as dot(a, b*c) or similar anyway.
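For three factors, the dot(a, b*c) pattern can be combined with the cyclic property of the trace mentioned in the opening post. A minimal sketch, assuming Julia 1.x; trprod_pair and trprod_cyclic are hypothetical names. For A (m×n), B (n×p), C (p×m), each cyclic variant forms one pairwise product costing m*n*p multiplications, so the sketch picks the variant with the smallest intermediate, which also minimizes the final trace step and the temporary allocation.

```julia
using LinearAlgebra

# tr(X*Y) for X (r×s) and Y (s×r), using only the r*s products that reach the diagonal.
function trprod_pair(X::AbstractMatrix, Y::AbstractMatrix)
    size(Y) == reverse(size(X)) || throw(DimensionMismatch("size(X') must equal size(Y)"))
    out = zero(promote_type(eltype(X), eltype(Y)))
    for k in axes(X, 2), i in axes(X, 1)   # inner loop walks down column k of X
        out += X[i, k] * Y[k, i]
    end
    return out
end

# Hypothetical three-factor trace tr(A*B*C) for A (m×n), B (n×p), C (p×m).
# By cyclicity tr(ABC) = tr(BCA) = tr(CAB). Every variant forms one product
# costing m*n*p multiplications, so choose the smallest intermediate.
function trprod_cyclic(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix)
    m, n = size(A)
    p = size(B, 2)
    (size(B, 1) == n && size(C) == (p, m)) || throw(DimensionMismatch())
    sBC, sCA, sAB = n * m, p * n, m * p    # sizes of B*C, C*A, A*B
    if sBC <= sCA && sBC <= sAB
        return trprod_pair(A, B * C)       # tr(A*(B*C))
    elseif sCA <= sAB
        return trprod_pair(B, C * A)       # tr(B*(C*A))
    else
        return trprod_pair(C, A * B)       # tr(C*(A*B))
    end
end
```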


#9

It checks the lengths, but not the sizes:

julia> A = rand(4,4);

julia> B = rand(8,2);

julia> dot(A,B)
3.8210283956714974

#10

Huh, I didn’t see that. That feels like a bug.


#11

I always understood it to be intentional (which is why I consider the trace of the product of matrices to be a distinct, more narrowly specified, operation). It would be nice to verify whether or not this is indeed intentional, but it sounds like we agree on everything else.


#12

Thank you everyone for the replies so far. (@mohamed82008 and @Elrod, I am still thinking about your posts but don’t want to delay thanking you :slightly_smiling_face:.)


#13

We’ll see: https://github.com/JuliaLang/julia/issues/28617


#14

The dot function is not optimally efficient. Simple for loops are faster, at least for 1000x1000 matrices.
The problem is that we need to take the transpose, and, unlike for mul!, there aren’t optimized methods that assume the arguments are transposed.
Try this vs dot(A,B') or dot(A',B):

function trprod4(A,B)
    @boundscheck size(A) == size(B') || throw(BoundsError())
    out = zero(promote_type(eltype(A),eltype(B)))
    @inbounds for j ∈ 1:size(A,1), i ∈ 1:size(A,2)
        out += A[j,i]*B[i,j]
    end
    out
end

If both matrices are Hermitian, you can simply loop over the upper triangle, doubling the real part of each off-diagonal product (for real symmetric matrices, simply doubling the product). That’d probably be faster.
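A sketch of the upper-triangle loop, assuming Julia 1.x, ordinary 1-based arrays, and inputs that really are Hermitian (the function does not check this; trprod_herm is a hypothetical name). For Hermitian A and B the (i, j) and (j, i) terms of the trace sum are complex conjugates of each other, so each off-diagonal pair contributes 2*real(A[i,j]*B[j,i]), and B[j,i] can be read as conj(B[i,j]) so that only the upper triangle of either matrix is touched. The result is returned as a real number.

```julia
using LinearAlgebra

# Hypothetical tr(A*B) for Hermitian A and B, reading only the upper triangles.
# Assumes (without checking) that A and B are Hermitian; the result is real.
function trprod_herm(A::AbstractMatrix, B::AbstractMatrix)
    n = LinearAlgebra.checksquare(A)
    size(B) == (n, n) || throw(DimensionMismatch("A and B must have the same size"))
    out = zero(real(promote_type(eltype(A), eltype(B))))
    @inbounds for j in 1:n
        # strictly upper part of column j: pairs (i,j) and (j,i) combined,
        # using B[j,i] == conj(B[i,j]) for Hermitian B
        for i in 1:j-1
            out += 2 * real(A[i, j] * conj(B[i, j]))
        end
        out += real(A[j, j] * B[j, j])   # diagonal entries are real
    end
    return out
end
```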


#15

FWIW, the check can be added to the macro. Also, I see two different problems being addressed here: the first is dispatching to an efficient implementation, and the second is the syntactic sugar. Whether Compat.dot (formerly vecdot) is good enough or not is orthogonal to the syntax problem. Coming up with “optimized” implementations for different combinations of matrices is interesting, but there is probably some trade-off between allocations and cache locality. This should also be possible to formulate as a mapreduce, for which you can try tmapreduce from KissThreading.jl to make use of type-stable shared-memory parallelism.
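The mapreduce formulation might look like the following serial sketch, assuming Julia 1.x and using only Base's mapreduce (trprod_mapreduce is a hypothetical name). Each Cartesian index (i, k) of A is mapped to A[i,k]*B[k,i] and the results are reduced with +. A threaded mapreduce such as KissThreading.jl's tmapreduce could in principle be swapped in, but its exact signature is not verified here.

```julia
using LinearAlgebra

# Hypothetical mapreduce form of tr(A*B) for A (m×n) and B (n×m):
# map each index pair (i, k) of A to A[i,k]*B[k,i], then reduce with +.
function trprod_mapreduce(A::AbstractMatrix, B::AbstractMatrix)
    size(B) == reverse(size(A)) || throw(DimensionMismatch("size(A') must equal size(B)"))
    T = promote_type(eltype(A), eltype(B))
    return mapreduce(I -> A[I[1], I[2]] * B[I[2], I[1]], +, CartesianIndices(A);
                     init = zero(T))
end
```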