Fast `diag(A' * B * A)`

I assume you meant for `A` to be `k × n`? Here are a few options:

#+begin_src julia
using LinearAlgebra, BenchmarkTools, Tullio, LoopVectorization

f1(A, B) = diag(A' * B * A)
f2(A, B) = vec(sum(conj.(A) .* (B * A); dims=1))   # conj is a no-op for real A
f3(A, B) = @tullio D[i] := conj(A[j, i]) * B[j, k] * A[k, i]

# Manual triple loop, vectorized and multithreaded by LoopVectorization's @tturbo
function f4(A::Matrix{T}, B::Matrix{U}) where {T, U}
    V = promote_type(T, U)
    @assert size(A, 1) == size(B, 1) == size(B, 2)
    D = zeros(V, size(A, 2))
    @tturbo for i ∈ eachindex(D)
        for j ∈ axes(B, 1)
            for k ∈ axes(B, 2)
                D[i] += conj(A[j, i]) * B[j, k] * A[k, i]
            end
        end
    end
    D
end
f5(A, B) = map(LinearAlgebra.dot, eachcol(A), eachcol(B * A))

# Like f4, but accumulates each D[i] in a local variable before storing it
function f6(A::Matrix{T}, B::Matrix{U}) where {T, U}
    V = promote_type(T, U)
    @assert size(A, 1) == size(B, 1) == size(B, 2)
    D = Vector{V}(undef, size(A, 2))
    @tturbo for i ∈ eachindex(D)
        di = zero(eltype(D))
        for j ∈ axes(B, 1)
            for k ∈ axes(B, 2)
                di += conj(A[j, i]) * B[j, k] * A[k, i]
            end
        end
        D[i] = di
    end
    D  # return the result; the loop itself evaluates to `nothing`
end

let n = 1000, k = 10
    A = randn(k, n)
    B = randn(k, k)
    for f ∈ (f1, f2, f3, f4, f5, f6)
        @assert f(A, B) ≈ f1(A, B)  # every variant must agree with the naive version
        print(f, "   ")
        @btime $f($A, $B)
    end
end;
#+end_src
#+RESULTS:
: f1     599.030 μs (5 allocations: 7.71 MiB)
: f2     21.550 μs (7 allocations: 164.36 KiB)
: f3     20.880 μs (1 allocation: 7.94 KiB)
: f4     4.636 μs (1 allocation: 7.94 KiB)
: f5     23.940 μs (3 allocations: 86.11 KiB)
: f6     4.481 μs (1 allocation: 7.94 KiB)

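Writing the loops by hand with LoopVectorization.jl's `@tturbo` (f4 and f6) is the winner here. It is roughly 5× faster than the broadcasting, Tullio and `dot` versions (f2, f3, f5) and more than 100× faster than the naive `diag(A' * B * A)`.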
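As a sanity check (not part of the benchmarks above): every variant computes the same thing, because entry `i` of `diag(A' * B * A)` is the quadratic form `A[:, i]' * B * A[:, i]`. That needs only O(n·k²) work. f1 instead materializes the whole n × n product, which costs O(n²·k) work and accounts for the 7.71 MiB allocation. The sketch below checks the identity using only the `LinearAlgebra` stdlib. `quadform_diag` is a hypothetical helper name, and complex inputs are used so that a missing `conj` would show up.

```julia
using LinearAlgebra

# Hypothetical helper: diag(A' * B * A) computed column by column as the
# quadratic forms A[:, i]' * B * A[:, i], never forming the n × n product.
function quadform_diag(A::AbstractMatrix, B::AbstractMatrix)
    BA = B * A  # k × n, O(n k²) work
    # dot(x, y) conjugates x, matching the adjoint in A'
    return [dot(view(A, :, i), view(BA, :, i)) for i in axes(A, 2)]
end

k, n = 10, 200
A = randn(ComplexF64, k, n)
B = randn(ComplexF64, k, k)

reference = diag(A' * B * A)  # what f1 computes
@assert quadform_diag(A, B) ≈ reference

# The broadcast form must conjugate explicitly once A is complex:
@assert vec(sum(conj.(A) .* (B * A); dims=1)) ≈ reference
@assert !(vec(sum(A .* (B * A); dims=1)) ≈ reference)
```

f5 gets the conjugation for free because `dot` conjugates its first argument. The broadcasting and Tullio forms only agree with `diag(A' * B * A)` on real inputs unless they call `conj` themselves.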


Edit1: I missed one of the suggestions above, so I added it as f5.
Edit2: I messed up the order of indices in f4, so it was faster than it should have been. Even with the fix, it’s still the fastest option, just by a smaller margin.
Edit3: Fixed the problem pointed out by mikmoore in post #9 of this thread. This again slows f4 down, but it’s still the fastest.
Edit4: I found yet another problem with f3 and f4: I had accidentally written `B[j, k] * A[j, i]` instead of `B[j, k] * A[k, i]`. Fixing this slows both down a bit. f3 is hit harder than f4, and f4 remains the fastest.
Edit5: Added another LoopVectorization.jl version (f6) that accumulates each `D[i]` in a local variable. This reduces the number of array accesses and speeds things up a little.
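The local-accumulator idea behind f6 does not depend on LoopVectorization.jl. Below is a minimal plain-Julia sketch of the same access pattern, assuming a simple `@inbounds` loop is enough for illustration. `quadform_diag_loop` is a hypothetical name, and without `@tturbo` this will not match f6’s timings. Each `D[i]` is stored exactly once, after its sum has been built up in the local `di`.

```julia
using LinearAlgebra

# Hypothetical plain-Julia version of the f6 access pattern (no LoopVectorization).
function quadform_diag_loop(A::Matrix{T}, B::Matrix{U}) where {T, U}
    V = promote_type(T, U)
    @assert size(A, 1) == size(B, 1) == size(B, 2)
    D = Vector{V}(undef, size(A, 2))
    @inbounds for i in axes(A, 2)
        di = zero(V)              # accumulate in a local variable, not in D[i]
        for k in axes(B, 2)
            aki = A[k, i]         # loop-invariant for the inner j loop
            for j in axes(B, 1)   # j is the first index: contiguous reads of A and B
                di += conj(A[j, i]) * B[j, k] * aki
            end
        end
        D[i] = di                 # single store per column
    end
    return D
end

A = randn(10, 1000)
B = randn(10, 10)
@assert quadform_diag_loop(A, B) ≈ diag(A' * B * A)
```

The inner loop runs over `j`, so `A[j, i]` and `B[j, k]` are read down columns, which matches Julia’s column-major layout.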
