Review: pth matrix roots

Hello,

To test the accuracy of the Schur-Padé approximation of A^r, for A a matrix and r a real number, I implemented a method rootm to compute A^(1/q) with q::Int. This is a direct generalisation of sqrtm = A^(1/2) (PR20214) and would allow the accurate computation of A^(p//q) for any integers p, q. The algorithm solves X^p = A for X using a recurrence relation that derives directly from writing out the product X^p. Anyway, the performance of this code sucks. Why?
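For context, the p = 2 case of the recurrence looks roughly like the following simplified sketch (not my actual _rootm, which handles general p; it assumes a real upper triangular A with positive diagonal):

# Sketch: X such that X^2 = A, for upper triangular A.
# The general p-th root uses the same column-by-column recurrence,
# with more cross terms and the divisor generalising to
# sum(xii^(p-1-q)*xjj^q for q in 0:(p-1)).
function _root2_sketch(A)
    n = size(A, 1)
    X = zeros(eltype(A), n, n)
    for j = 1:n
        X[j,j] = sqrt(A[j,j])               # diagonal: x_jj^2 = a_jj
        for i = j-1:-1:1
            s = A[i,j]
            for k = i+1:j-1
                s -= X[i,k]*X[k,j]          # subtract already-computed cross terms of (X^2)_ij
            end
            X[i,j] = s / (X[i,i] + X[j,j])  # p = 2 divisor: x_ii + x_jj
        end
    end
    return X
end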

A = randn(127,127)
A = UpperTriangular(schurfact(A'*A)[:T])

@benchmark _sqrtm(A)
  median time:      744.535 μs (0.00% GC)
@benchmark _rootm(A,2,Val{true})
  median time:      145.209 ms (1.93% GC)

@profile _rootm(A,2,Val{true})
ProfileView.view()

Using @code_warntype, I get no red ink. The algorithm is not much more complex than sqrtm, yet it is 200x slower. I don't understand the profile: where does the inference come from, and why are there several distinct calls to _rootm in the profile even though the function is called only once?

Performance for pth roots with p>2 is obviously the more interesting point, but p=2 is a good benchmark and I'd like to understand the performance difference.

Any comments and advice appreciated

Two suggestions that occur to me when I glance at your code:

  • Don't call sum in your inner loops. Write out the loops.

  • Don't compute powers as x^q or similar in a loop over q: in each loop iteration, accumulate the product by multiplying repeatedly by x (essentially, use Horner's method).

For example, replace xij /= sum(xii^(p-1-q)*xjj^q for q in 0:(p-1)) by something like:

# compute xij /= sum(xii^(p-1-q)*xjj^q for q in 0:(p-1)):
āˆ‘ = āˆ = xii^(p-1)
Ī¾ = xjj / xii
for q = 1:p-1
    āˆ *= Ī¾
    āˆ‘ += āˆ
end
xij /= āˆ‘

That worked! :tada:

I'm amazed:

  median time:      781.725 μs (0.00% GC)

That is within 10% of the sqrtm function! The performance gain is the same for rootm(A,p) with p>2. Thank you Steven!

Almost all of this must have come from eliminating the sums, because in this case p = 2, so 1 <= q < p means q = 1, and accumulating the powers with Horner's method can't have contributed much to the speedup. Is there any analysis anywhere of why/when sum is slow in Julia?

sum is a perfectly good function, but you have to realize that in a sum of only two or three elements, any call to a general summation function (or any other function call) has a lot of overhead compared to the trivial amount of work it is doing (one or two additions). A couple of floating-point additions are fast.

The cost of setting up a generator object adds even more overhead. If you were summing many numbers, this wouldn't matter, but compared to the cost of a couple of additions, the creation of any complex object is massively costly.

Finally, computing a general exponentiation for every summand (even with an integer exponent, it still involves a function call and various checks, e.g. of the exponent's sign and magnitude) is vastly more expensive than a single multiplication.

The rule of thumb is that the cheaper the operation you are performing, the more careful you have to be with high-level abstractions.
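You can see the gap with a quick micro-benchmark. This is an illustrative sketch (the names f and g are just for this example), and exact timings will vary:

using BenchmarkTools

xii, xjj, p = 1.3, 0.7, 3

# generator + sum, with a general integer power for every summand
f(xii, xjj, p) = sum(xii^(p-1-q) * xjj^q for q in 0:(p-1))

# written-out loop accumulating a running product (Horner-style):
# one power up front, then only multiplications and additions
function g(xii, xjj, p)
    ∑ = ∏ = xii^(p-1)
    ξ = xjj / xii
    for q = 1:p-1
        ∏ *= ξ
        ∑ += ∏
    end
    return ∑
end

@benchmark f($xii, $xjj, $p)
@benchmark g($xii, $xjj, $p)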


(Looks like you also hit a type-inference bug: https://github.com/JuliaLang/julia/issues/20517)