# Review: pth matrix roots

Hello,

to test the accuracy of the Schur–Padé approximation of `A^r` for a matrix `A` and real `r`, I implemented a method `rootm` that computes `A^(1/q)` for `q::Int`. This is a direct generalisation of `sqrtm(A) = A^(1/2)` (PR20214) and would allow the accurate computation of `A^(p//q)` for any integers `p, q`. The algorithm solves `X^p = A` for `X` via a recurrence relation that follows directly from writing out the product `X^p`. Anyway, the performance of this code sucks. Why?
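
For reference, here is a stripped-down sketch of how the recurrence falls out for upper triangular `A` (not my actual `_rootm`; `root_sketch` is a made-up name, and I assume a positive real diagonal so the real root exists):

``````
# Sketch of the recurrence X^p = A for upper triangular A: diagonal and
# first superdiagonal only; the full code fills in the remaining
# superdiagonals the same way.
function root_sketch(A::UpperTriangular, p::Int)
    n = size(A, 1)
    X = zeros(eltype(A), n, n)
    for i in 1:n
        X[i,i] = A[i,i]^(1/p)  # (X^p)[i,i] = X[i,i]^p
    end
    for i in 1:n-1
        # Writing out X^p gives
        #   (X^p)[i,i+1] = X[i,i+1] * sum(X[i,i]^(p-1-q) * X[i+1,i+1]^q for q in 0:p-1),
        # so solve for X[i,i+1]:
        X[i,i+1] = A[i,i+1] / sum(X[i,i]^(p-1-q) * X[i+1,i+1]^q for q in 0:p-1)
    end
    return UpperTriangular(X)
end
``````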

``````
A = randn(127,127)
A = UpperTriangular(schurfact(A'*A)[:T])

@benchmark _sqrtm(A)
median time:      744.535 μs (0.00% GC)
@benchmark _rootm(A,2,Val{true})
median time:      145.209 ms (1.93% GC)

@profile _rootm(A,2,Val{true})
ProfileView.view()
``````

Using `@code_warntype`, I get no red ink. The algorithm is not that much more complex than `sqrtm`, yet it is 200x slower. I don't understand the profiling: where does the inference come from, and why are there several distinct calls to `_rootm` in the profile above, even though the function is called only once?

Performance for `p`th roots with `p>2` is obviously the more interesting point, but `p=2` is a good benchmark and I'd like to understand the performance difference.

Two suggestions that occur to me when I glance at your code:

• Donāt call `sum` in your inner loops. Write out the loops.

• Donāt compute powers as `x^q` or similar in a loop over `q`: in each loop iteration, accumulate the product by multiplying repeatedly by `x` (essentially, use Hornerās method).

For example, replace `xij /= sum(xii^(p-1-q)*xjj^q for q in 0:(p-1))` by something like:

``````
# compute xij /= sum(xii^(p-1-q)*xjj^q for q in 0:(p-1)):
s = t = xii^(p-1)
ξ = xjj / xii
for q = 1:p-1
    t *= ξ
    s += t
end
xij /= s
``````

That worked!

Iām amazed:

``````
  median time:      781.725 μs (0.00% GC)
``````

That is within 10% of the `sqrtm` function! The performance gain is the same for `rootm(A,p)` with `p>2`. Thank you Steven!

Almost all of this must have come from eliminating the `sum`s: in this case `p=2` and `1 <= q < p` means `q = 1`, so the Horner loop runs only once and the cumulative powers can't have done much for the speedup (for `p=2` the sum collapses to just `xii + xjj`). Is there any analysis of why/when `sum` is slow in Julia?

`sum` is a perfectly good function, but you have to realize that in a sum of only two or three elements, any call to a general summation function (or any other function call) is going to have a lot of overhead compared to the trivial amount of work it is doing (just one or two additions). A couple of floating-point additions is very fast.

The cost of setting up a generator object adds even more overhead. If you were summing many numbers, this wouldn't matter. But compared to the cost of a couple of additions, the creation of any complex object is massively costly.

Finally, a general exponentiation operation (even for an integer exponent, it still involves a function call and various checks, e.g. of the exponent's sign and magnitude) for every summand is vastly more expensive than a single multiplication.

The rule of thumb: the cheaper the operation you are performing, the more careful you have to be with high-level abstractions.
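
To see the difference concretely, you can benchmark the two styles in isolation. This is a hypothetical microbenchmark (`f_gen` and `f_horner` are made-up names, and I picked a three-term sum as the example):

``````
using BenchmarkTools

# generator + `sum` + integer powers, as in the original inner loop:
f_gen(a, b) = sum(a^(2-q) * b^q for q in 0:2)

# written-out Horner version: one initial power, then multiply-adds:
function f_horner(a, b)
    s = t = a^2
    ξ = b / a
    t *= ξ; s += t   # adds a*b
    t *= ξ; s += t   # adds b^2
    return s
end

a, b = 1.7, 2.3
@benchmark f_gen($a, $b)
@benchmark f_horner($a, $b)
``````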


(Looks like you also hit a type-inference bug: https://github.com/JuliaLang/julia/issues/20517)