I assume you meant that you wanted A to be k × n? Here are a few options:
#+begin_src julia
using LinearAlgebra, BenchmarkTools  # for diag, dot, and @btime
using Tullio, LoopVectorization
# f1: naive version, forms the full n×n product before taking the diagonal
f1(A, B) = diag(A' * B * A)
# f2: avoids the full product by summing each column of A .* (B * A)
f2(A, B) = vec(sum(A .* (B * A); dims=1))
# f3: the same contraction written with Tullio
f3(A, B) = @tullio D[i] := A[j, i] * B[j, k] * A[k, i]
# f4: explicit loops with LoopVectorization's threaded @tturbo
function f4(A::Matrix{T}, B::Matrix{U}) where {T, U}
    V = promote_type(T, U)
    @assert size(A, 1) == size(B, 1) == size(B, 2)
    D = zeros(V, size(A, 2))
    @tturbo for i ∈ eachindex(D)
        for j ∈ axes(B, 1)
            for k ∈ axes(B, 2)
                D[i] += conj(A[j, i]) * B[j, k] * A[k, i]
            end
        end
    end
    D
end
# f5: dot of each column of A with the corresponding column of B * A
f5(A, B) = map(LinearAlgebra.dot, eachcol(A), eachcol(B * A))
# f6: like f4 but accumulates into a local variable to reduce array accesses
function f6(A::Matrix{T}, B::Matrix{U}) where {T, U}
    V = promote_type(T, U)
    @assert size(A, 1) == size(B, 1) == size(B, 2)
    D = Vector{V}(undef, size(A, 2))
    @tturbo for i ∈ eachindex(D)
        di = zero(eltype(D))
        for j ∈ axes(B, 1)
            for k ∈ axes(B, 2)
                di += conj(A[j, i]) * B[j, k] * A[k, i]
            end
        end
        D[i] = di
    end
    D
end
let n = 1000, k = 10
    A = randn(k, n)
    B = randn(k, k)
    for f ∈ (f1, f2, f3, f4, f5, f6)
        print(f, " ")
        @btime $f($A, $B)
    end
end;
#+end_src
#+RESULTS:
: f1 599.030 μs (5 allocations: 7.71 MiB)
: f2 21.550 μs (7 allocations: 164.36 KiB)
: f3 20.880 μs (1 allocation: 7.94 KiB)
: f4 4.636 μs (1 allocation: 7.94 KiB)
: f5 23.940 μs (3 allocations: 86.11 KiB)
: f6 4.481 μs (1 allocation: 7.94 KiB)
Looks like writing the manual loop in LoopVectorization.jl is the winner here.
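As a quick sanity check (my own addition, not part of the original benchmark), every variant can be compared against the naive f1 on the same random setup:

#+begin_src julia
# Sanity check: all variants should agree with the naive f1 up to roundoff.
let n = 1000, k = 10
    A = randn(k, n)
    B = randn(k, k)
    reference = f1(A, B)
    for f ∈ (f2, f3, f4, f5, f6)
        @assert f(A, B) ≈ reference
    end
end;
#+end_src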
Edit1: I missed one of the suggestions above, so I added it as f5.
Edit2: I messed up the order of indices for f4, so it was faster than it should have been. However, even with this fix it's still the fastest option, just by a smaller margin.
Edit3: Fixed the problem pointed out here: Fast `diag(A' * B * A)` - #9 by mikmoore. This again has a negative performance impact on f4, but it's still the fastest.
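If I read the linked post correctly, the issue was the missing conj(A[j, i]) factor: diag(A' * B * A) uses the conjugate transpose, so for complex inputs dropping the conjugation changes the result. A minimal sketch of my own (not from the thread) to illustrate:

#+begin_src julia
# Assumed illustration: with complex inputs, the conjugation matters.
let k = 4, n = 3
    A = randn(ComplexF64, k, n)
    B = randn(ComplexF64, k, k)
    with_conj    = [sum(conj(A[j, i]) * B[j, l] * A[l, i] for j in 1:k, l in 1:k) for i in 1:n]
    without_conj = [sum(A[j, i] * B[j, l] * A[l, i] for j in 1:k, l in 1:k) for i in 1:n]
    @show with_conj ≈ diag(A' * B * A)     # true
    @show without_conj ≈ diag(A' * B * A)  # false in general
end;
#+end_src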
Edit4: I found yet another problem with f3 and f4 where I accidentally wrote B[j, k] * A[j, i] instead of B[j, k] * A[k, i]. Fixing this slows down f3 and f4 a bit, but f3 is hit harder than f4, and f4 remains the fastest.
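For concreteness, here is a quick reconstruction of that bug (my own sketch, not from the original post); the wrong contraction sums A[j, i]^2 * B[j, k] over j and k, which is not the diagonal we want:

#+begin_src julia
# Hypothetical f3_wrong reproduces the Edit4 typo: A[j, i] where A[k, i] belongs.
f3_wrong(A, B) = @tullio D[i] := A[j, i] * B[j, k] * A[j, i]
let n = 100, k = 10
    A = randn(k, n)
    B = randn(k, k)
    @show f3_wrong(A, B) ≈ diag(A' * B * A)  # false
end;
#+end_src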
Edit5: added another LoopVectorization.jl example (f6) that reduces the number of array accesses required, which speeds things up a little.