Review: pth matrix roots

Hello,

To test the accuracy of the Schur–PadĆ© approximation of A^r, for A a matrix and r a real number, I implemented a method rootm to calculate A^(1/q) for q::Int. This is a direct generalisation of sqrtm = A^(1/2) (PR20214) and would allow the accurate computation of A^(p//q) for any integers p, q. The algorithm solves X^p = A for X via a recurrence relation derived directly from writing out the product X^p. Anyway, the performance of this code sucks. Why?
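For concreteness, here is a minimal sketch of the p = 2 case of that recurrence (essentially the standard substitution behind sqrtm on an upper triangular factor); _rootm generalises this to arbitrary p, and the function name here is made up for illustration only:

function _sqrtm_sketch(T::UpperTriangular)
    # Solve X^2 = T for upper triangular X, one column at a time.
    n = size(T, 1)
    X = zeros(eltype(T), n, n)
    for j = 1:n
        X[j,j] = sqrt(T[j,j])
        for i = j-1:-1:1
            # (X^2)[i,j] = X[i,i]*X[i,j] + X[i,j]*X[j,j]
            #              + (sum over k = i+1:j-1 of X[i,k]*X[k,j])
            s = zero(eltype(T))
            for k = i+1:j-1
                s += X[i,k] * X[k,j]
            end
            X[i,j] = (T[i,j] - s) / (X[i,i] + X[j,j])
        end
    end
    return UpperTriangular(X)
end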

using BenchmarkTools, ProfileView

A = randn(127,127)
A = UpperTriangular(schurfact(A'*A)[:T])

@benchmark _sqrtm(A)
  median time:      744.535 μs (0.00% GC)
@benchmark _rootm(A,2,Val{true})
  median time:      145.209 ms (1.93% GC)

@profile _rootm(A,2,Val{true})
ProfileView.view()

Using @code_warntype, I get no red ink. The algorithm is not that much more complex than sqrtm, yet it is 200x slower. I don’t understand the profile: where do the calls into inference come from, and why are there several distinct calls to _rootm in the profile despite the function being called once?

Performance for pth roots with p > 2 is obviously the more interesting case, but p = 2 is a good benchmark and I’d like to understand the performance difference.

Any comments and advice appreciated.

Two suggestions that occur to me when I glance at your code:

  • Don’t call sum in your inner loops. Write out the loops.

  • Don’t compute powers as x^q or similar in a loop over q: in each loop iteration, accumulate the product by multiplying repeatedly by x (essentially, use Horner’s method).

For example, replace xij /= sum(xii^(p-1-q)*xjj^q for q in 0:(p-1)) by something like:

# compute xij /= sum(xii^(p-1-q)*xjj^q for q in 0:(p-1)):
āˆ‘ = āˆ = xii^(p-1)
ξ = xjj / xii
for q = 1:p-1
    āˆ *= ξ
    āˆ‘ += āˆ
end
xij /= āˆ‘
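
(As a sanity check: for p = 2 the sum is just xii + xjj, so this reduces to the familiar sqrtm update xij /= (xii + xjj).)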

That worked! šŸŽ‰

I’m amazed:

  median time:      781.725 μs (0.00% GC)

That is within 10% of the sqrtm function! The performance gain is the same for rootm(A,p) with p > 2. Thank you, Steven!

Almost all of this must have come from eliminating the sums, because with p = 2 the Horner loop over q = 1:p-1 runs only once, so the accumulated powers can’t have done much for the speedup. Is there any analysis anywhere of why/when sum is slow in Julia?

sum is a perfectly good function, but you have to realize that for a sum of only three elements, any call to a general summation function (or any other function call) carries a lot of overhead compared to the trivial amount of work it is doing (just two additions). Two floating-point additions are fast.

The cost of setting up a generator object adds even more overhead. If you were summing many numbers, this wouldn’t matter; but compared to the cost of two additions, the creation of any complex object is massively costly.

Finally, computing a general exponentiation for every summand (even with an integer exponent, this still involves a function call and various checks, e.g. of the exponent’s sign and magnitude) is vastly more expensive than a single multiplication.
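
To make the comparison concrete, here is a hypothetical microbenchmark (the names f and g are invented for this sketch, not taken from your code):

using BenchmarkTools

f(x) = sum(x^q for q in 0:2)  # generator object + sum call + general ^ per term
g(x) = 1 + x + x*x            # the same value: two additions, one multiplication

x = 1.7
@btime f($x)
@btime g($x)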

The rule of thumb is: the cheaper the operation you are performing, the more careful you have to be with high-level abstractions.


(Looks like you also hit a type-inference bug: https://github.com/JuliaLang/julia/issues/20517)