Best way to take trace of matrix product in Julia?

My question is essentially a rehash of the Python question posted here. Given two matrices A and B, I would like to compute the trace of A*B without evaluating the full matrix product. I’ve included a few implementations below:

using LinearAlgebra

A = rand(10000, 10000)
B = rand(10000, 10000)

function trace1(A, B)
    return tr(A*B)
end

function trace2(A, B)
    a = 0
    for i = 1:size(A, 1)
        for j = 1:size(A, 2)
            a += A[i, j]*B[j, i]
        end
    end
    return a
end

function trace3(A, B)
    a = 0
    for i = 1:size(A, 1)
        a += dot(A[i, :], B[:, i])
    end
    return a
end

function trace4(A, B)
    return sum(A.*B')
end

Timing these with @time gives

  9.427852 seconds (3 allocations: 762.940 MiB, 0.28% gc time)
  0.405829 seconds (1 allocation: 16 bytes)
  0.601993 seconds (40.00 k allocations: 1.491 GiB, 6.45% gc time)
  0.907124 seconds (3 allocations: 762.940 MiB, 0.04% gc time)

So the second implementation appears to be the fastest. But is there any way to do better? And is there a cleaner way of doing it that doesn’t involve writing out the for loops explicitly? (I’m aware it could be cleaned up a little by being more clever with the indexing, but a single-line implementation would be nice.)

trace5(A, B) = dot(A', B)

Note that

  • dot(A, B) is even faster, but this computes tr(A' * B), the Frobenius inner product. If you can re-arrange your surrounding calculations to use this (i.e. store your A transposed), it would be best by a large margin.
  • trace3 would be much better if you simply put @views in front of the function definition, so that it uses views for slices.
  • sum(A.*B') allocates a whole new array just to sum it, which slows it down.
  • It’s better to do this kind of benchmarking with BenchmarkTools.jl, e.g. with @btime trace2($A, $B) to interpolate the global variables. I would usually benchmark with a smaller matrix, e.g. 1000×1000, because I’m impatient.
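The @views fix mentioned above can be sketched like this (trace3_views is just an illustrative name; the body is otherwise trace3 from the question, with the type-stable accumulator discussed later in this thread):

```julia
using LinearAlgebra

# trace3 from the question, with @views applied to the whole function
# body so that A[i, :] and B[:, i] become non-allocating views, not copies.
@views function trace3_views(A, B)
    a = zero(eltype(A)) * zero(eltype(B))  # type-stable accumulator
    for i in 1:size(A, 1)
        a += dot(A[i, :], B[:, i])
    end
    return a
end

A = rand(100, 100); B = rand(100, 100)
trace3_views(A, B) ≈ tr(A * B)  # agrees with the naive approach
```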

If you really want to optimize the heck out of this, without transposing your A matrix so that you can use dot(A, B), then you have to start playing fancier games. A key problem is that A[i, j]*B[j, i] is going to access memory discontiguously for one of your two arrays, no matter which order you use for your loops, which will hurt spatial locality (cache-line utilization). To combat this, you’ll want to use some kind of blocking of your loops (or recursive/cache-oblivious blocking). Then you’ll also want to use e.g. LoopVectorization.jl to SIMD-optimize (and perhaps multi-thread) your loops. However, it’s rarely worth it to do this level of optimization unless profiling and analysis shows that this one little calculation is truly the limiting factor in your algorithm. And it would still be better to simply store A transposed so that you can use dot(A, B) (which can loop through both arrays in order, and uses a BLAS-1 call to boot).
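For illustration, a minimal cache-blocked version might look like the following (trace_blocked and the block size bs = 64 are my own untuned choices, a sketch rather than an optimized implementation):

```julia
using LinearAlgebra

# Sketch of loop blocking: process bs-by-bs tiles so that the array
# accessed discontiguously still gets reasonable cache-line reuse.
function trace_blocked(A, B; bs = 64)
    n, m = size(A)
    a = zero(eltype(A)) * zero(eltype(B))
    for jb in 1:bs:m, ib in 1:bs:n
        for j in jb:min(jb + bs - 1, m), i in ib:min(ib + bs - 1, n)
            @inbounds a += A[i, j] * B[j, i]
        end
    end
    return a
end

A = rand(1000, 1000); B = rand(1000, 1000)
trace_blocked(A, B) ≈ tr(A * B)  # same result, better spatial locality
```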


Also, you should use something like

a = zero(A[begin] * B[begin])

rather than a = 0 to avoid type instability. You want a to initially be of the same type as it ends up after the calculation. In your example it starts out as an Int and ends up as a Float64. Fixing this will speed things up.


On my machine this is a little faster than trace5 in some cases:

function trace6(a::Matrix, b::Matrix)
    size(a) == reverse(size(b)) || error("Matrices must be conformable")
    bt = transpose(b)  # lazy transpose; no copy is made
    return sum(a[i] * bt[i] for i in eachindex(a, bt))
end

For what it’s worth, using Tullio; @tullio out := A[i, j] * B[j, i] will do recursive blocking (and is 3x faster than dot(A', B) for me). If LoopVectorization.jl is loaded, then the macro uses that for SIMD (another 2x here, roughly matching dot(A, B) without the transpose).


tangentially…

julia> a = rand(3,3); b = rand(3,3)
3×3 Matrix{Float64}:
 0.683101  0.00563371  0.284401
 0.114074  0.369375    0.942211
 0.167023  0.618831    0.199613

julia> @which dot(a,b)
dot(x::Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S}, Base.ReshapedArray{T, N, A} where {N, A<:Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {T, N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S}, SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}}, SubArray{T, var"#s990", var"#s128", I, true} where {var"#s990", var"#s128"<:Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S}, Base.ReshapedArray{T, N, A} where {N, A<:Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {T, N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S}, SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}}, DenseArray{T}}, I}, DenseArray{T}}, y::Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S}, Base.ReshapedArray{T, N, A} where {N, A<:Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {T, N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S}, SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, 
Vararg{Any}}}}, DenseArray}}, SubArray{T, var"#s990", var"#s128", I, true} where {var"#s990", var"#s128"<:Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S}, Base.ReshapedArray{T, N, A} where {N, A<:Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {T, N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S}, SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}}, DenseArray{T}}, I}, DenseArray{T}}) where T<:Union{Float32, Float64}
     @ LinearAlgebra ~/Documents/github/dotFiles/homedir/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:15

I’m not sure that is really true (though it’s probably safest to assume it is). Recently I found that simply writing a = 0 does not produce any warnings under @code_warntype, and my assumption is that the compiler will simply promote everything if you call the function with a float or Complex number.
Of course, if no such promotion exists, it would error.
Would love to hear some insight if I am missing anything.

That must be a bug, right? Did you report it?

Seems like it is handled as a Union, and performance is the same in this case:

julia> function trace2(A, B)
           a = 0
           for i = 1:size(A, 1)
               for j = 1:size(A, 2)
                   a += A[i, j]*B[j, i]
               end
           end
           return a
       end
trace2 (generic function with 1 method)

julia> function trace2_stable(A, B)
           a = zero(A[begin] * B[begin])
           for i = 1:size(A, 1)
               for j = 1:size(A, 2)
                   a += A[i, j]*B[j, i]
               end
           end
           return a
       end
trace2_stable (generic function with 1 method)

julia> A = randn((3,3)); B = randn((3,3));

julia> @code_warntype trace2(A, B)
MethodInstance for trace2(::Matrix{Float64}, ::Matrix{Float64})
  from trace2(A, B) @ Main REPL[18]:1
Arguments
  #self#::Core.Const(trace2)
  A::Matrix{Float64}
  B::Matrix{Float64}
Locals
  @_4::Union{Nothing, Tuple{Int64, Int64}}
  a::Union{Float64, Int64}
  @_6::Union{Nothing, Tuple{Int64, Int64}}
  i::Int64
  j::Int64
Body::Union{Float64, Int64}
1 ─       (a = 0)
│   %2  = Main.size(A, 1)::Int64
│   %3  = (1:%2)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│         (@_4 = Base.iterate(%3))
│   %5  = (@_4 === nothing)::Bool
│   %6  = Base.not_int(%5)::Bool
└──       goto #7 if not %6
2 ┄ %8  = @_4::Tuple{Int64, Int64}
│         (i = Core.getfield(%8, 1))
│   %10 = Core.getfield(%8, 2)::Int64
│   %11 = Main.size(A, 2)::Int64
│   %12 = (1:%11)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│         (@_6 = Base.iterate(%12))
│   %14 = (@_6 === nothing)::Bool
│   %15 = Base.not_int(%14)::Bool
└──       goto #5 if not %15
3 ┄ %17 = @_6::Tuple{Int64, Int64}
│         (j = Core.getfield(%17, 1))
│   %19 = Core.getfield(%17, 2)::Int64
│   %20 = a::Union{Float64, Int64}
│   %21 = Base.getindex(A, i, j)::Float64
│   %22 = Base.getindex(B, j, i)::Float64
│   %23 = (%21 * %22)::Float64
│         (a = %20 + %23)
│         (@_6 = Base.iterate(%12, %19))
│   %26 = (@_6 === nothing)::Bool
│   %27 = Base.not_int(%26)::Bool
└──       goto #5 if not %27
4 ─       goto #3
5 ┄       (@_4 = Base.iterate(%3, %10))
│   %31 = (@_4 === nothing)::Bool
│   %32 = Base.not_int(%31)::Bool
└──       goto #7 if not %32
6 ─       goto #2
7 ┄       return a


julia> @code_warntype trace2_stable(A, B)
MethodInstance for trace2_stable(::Matrix{Float64}, ::Matrix{Float64})
  from trace2_stable(A, B) @ Main REPL[19]:1
Arguments
  #self#::Core.Const(trace2_stable)
  A::Matrix{Float64}
  B::Matrix{Float64}
Locals
  @_4::Union{Nothing, Tuple{Int64, Int64}}
  a::Float64
  @_6::Union{Nothing, Tuple{Int64, Int64}}
  i::Int64
  j::Int64
Body::Float64
1 ─ %1  = Base.firstindex(A)::Core.Const(1)
│   %2  = Base.getindex(A, %1)::Float64
│   %3  = Base.firstindex(B)::Core.Const(1)
│   %4  = Base.getindex(B, %3)::Float64
│   %5  = (%2 * %4)::Float64
│         (a = Main.zero(%5))
│   %7  = Main.size(A, 1)::Int64
│   %8  = (1:%7)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│         (@_4 = Base.iterate(%8))
│   %10 = (@_4 === nothing)::Bool
│   %11 = Base.not_int(%10)::Bool
└──       goto #7 if not %11
2 ┄ %13 = @_4::Tuple{Int64, Int64}
│         (i = Core.getfield(%13, 1))
│   %15 = Core.getfield(%13, 2)::Int64
│   %16 = Main.size(A, 2)::Int64
│   %17 = (1:%16)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│         (@_6 = Base.iterate(%17))
│   %19 = (@_6 === nothing)::Bool
│   %20 = Base.not_int(%19)::Bool
└──       goto #5 if not %20
3 ┄ %22 = @_6::Tuple{Int64, Int64}
│         (j = Core.getfield(%22, 1))
│   %24 = Core.getfield(%22, 2)::Int64
│   %25 = a::Float64
│   %26 = Base.getindex(A, i, j)::Float64
│   %27 = Base.getindex(B, j, i)::Float64
│   %28 = (%26 * %27)::Float64
│         (a = %25 + %28)
│         (@_6 = Base.iterate(%17, %24))
│   %31 = (@_6 === nothing)::Bool
│   %32 = Base.not_int(%31)::Bool
└──       goto #5 if not %32
4 ─       goto #3
5 ┄       (@_4 = Base.iterate(%8, %15))
│   %36 = (@_4 === nothing)::Bool
│   %37 = Base.not_int(%36)::Bool
└──       goto #7 if not %37
6 ─       goto #2
7 ┄       return a


Seems you’re right. Looks like union splitting is still necessary here because the compiler does not know that a will never be used as an Int. Although union splitting is alright, the completely type-stable version is still a bit faster.

julia> @btime trace2($A1,$B1)
  9.500 μs (0 allocations: 0 bytes)
2472.71590461605

julia> @btime trace2_stable($A1,$B1)
  7.785 μs (0 allocations: 0 bytes)
2472.71590461605

Looks like there have been some improvements in recent versions of Julia. Under Julia 1.6.7, with @code_warntype the type union is highlighted in red, and the performance difference is greater:

julia> A = randn((3,3)); B = randn((3,3));

julia> @btime trace2($A,$B)
  14.114 ns (0 allocations: 0 bytes)
2.2702241217886

julia> @btime trace2_stable($A,$B)
  8.900 ns (0 allocations: 0 bytes)
2.2702241217886

For setting the initial value in trace2 and trace3 I would suggest a = zero(eltype(A)) * zero(eltype(B)). This still works if the arrays happen to be empty.
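A quick sketch of the difference on empty inputs, just to illustrate the point:

```julia
A = zeros(0, 0); B = zeros(0, 0)  # empty matrices

zero(eltype(A)) * zero(eltype(B))  # 0.0 — fine even when A and B are empty

# By contrast, zero(A[begin] * B[begin]) would throw a BoundsError here,
# because an empty matrix has no first element to index.
```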

But @mcabbott’s suggestion with Tullio and LoopVectorization seems to be the superior solution here by quite some margin. What an amazing package. (Though does anyone know why it has to do a single 16-byte allocation? Smells like some kind of box or dynamic dispatch?)

julia> using BenchmarkTools, LoopVectorization, Tullio

julia> A = randn(10000, 10000); B = randn(10000, 10000);

julia> tracetullio(A, B) = @tullio _ := A[i, j] * B[j, i]
tracetullio (generic function with 1 method)

julia> @btime tracetullio($A, $B);
  165.339 ms (1 allocation: 16 bytes)

Tullio uses Julia’s @spawn for multi-threading (if it thinks the arrays are large enough), which makes some small allocations.


Hm… Are you sure that’s the culprit?

julia> A = randn(1, 1); B = randn(1, 1);

julia> @btime tracetullio($A, $B);
  25.524 ns (1 allocation: 16 bytes)

that’s not a bug, just a very elaborate method type annotation

I mean a bug in that Julia shows horribly complicated type annotations instead of the equivalent dot(x::StridedVecLike{T}, y::StridedVecLike{T}) where {T<:BlasReal}, which one can see with @less.

@less drops you into the source code; @which is showing you the real types, since StridedVecLike is just an alias.

Yes, my point is that it should not? Rather it should show the simpler aliases, like this:

julia> @which cross([1], [1])
cross(a::AbstractVector, b::AbstractVector)

julia> AbstractVector
AbstractVector (alias for AbstractArray{T, 1} where T)

I naively thought this simplification was done automatically, but maybe Julia simplifies only a few hard-coded cases? Then I wonder if it would be possible to auto-generate hardcoded simplification rules for all the type annotations that are used in the standard library…