In-place computation of Q (from QR decomposition)

I could use a fast computation of an orthonormalized basis to represent columns of a tall and narrow dense matrix.

Taking advantage of an idle hour, I came up with this modified Gram-Schmidt:

using LinearAlgebra

function coldot(A, j, i)
    m = size(A, 1)
    r = zero(eltype(A))
    @simd for k in 1:m
        r += A[k, i] * A[k, j]
    end
    return r
end

function colnorm(A, j)
    return sqrt(coldot(A, j, j))
end

function colsubt!(A, i, j, r)
    m = size(A, 1)
    @simd for k in 1:m
        A[k, i] -= r * A[k, j]
    end
end

function normalizecol!(A, j)
    m = size(A, 1)
    r = inv(colnorm(A, j))  # reciprocal stays in A's element type (e.g. Float32)
    @simd for k in 1:m
        A[k, j] *= r
    end
end

function mgsortho!(A)
    @show m, n = size(A)
    normalizecol!(A, 1)
    for j in 2:n 
        Base.Threads.@threads for i in j:n
            colsubt!(A, i, j-1, coldot(A, j-1, i))
        end
        normalizecol!(A, j)
    end
    return A
end

Initially I found it (predictably) slower than the built-in qr!. For instance, for a 200,000\times 200 matrix, my version was more than three times slower.

However, if I have a multiprocessor, I can use more threads, and things change quite a bit.
For a 2,000,000\times 400 matrix, using multiple threads:

# threads          qr!       mgsortho!
1                 232 s           305 s
2                 216 s           170 s
4                 212 s           141 s

To conclude, no question but a challenge: care to make it faster still?


I don’t know how LAPACK makes its QR fast, but I’m sure it does QR by Householder reflections rather than Gram-Schmidt orthogonalization. And Householder is slower than Gram-Schmidt, by a factor of 2, I think.

Gram-Schmidt is prone to producing non-unitary Qs when applied to poorly conditioned matrices, and it produces less accurate answers to Ax=b problems than Householder. You lose twice as many digits of accuracy for a given poorly-conditioned matrix.
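A minimal sketch of the orthogonality loss described above, assuming a synthetic test matrix with a prescribed condition number. The helpers `mgs` and `illconditioned` are hypothetical names introduced here for illustration; they are not from the thread, and `mgs` is an unoptimized out-of-place variant.

```julia
using LinearAlgebra

# Plain modified Gram-Schmidt returning Q only (out-of-place, unoptimized).
function mgs(A)
    Q = float.(copy(A))
    n = size(Q, 2)
    for j in 1:n
        Q[:, j] ./= norm(Q[:, j])
        for i in j+1:n
            r = dot(Q[:, j], Q[:, i])
            Q[:, i] .-= r .* Q[:, j]
        end
    end
    return Q
end

# Hypothetical helper: m-by-n matrix with singular values spread
# logarithmically between 1 and 1/kappa.
function illconditioned(m, n, kappa)
    U = Matrix(qr(randn(m, n)).Q)
    V = Matrix(qr(randn(n, n)).Q)
    s = exp10.(range(0.0, -log10(kappa); length = n))
    return U * Diagonal(s) * V'
end

A  = illconditioned(500, 20, 1e8)
Qm = mgs(A)
Qh = Matrix(qr(A).Q)

@show cond(A)
@show norm(Qm' * Qm - I)   # departure from orthogonality, MGS
@show norm(Qh' * Qh - I)   # departure from orthogonality, Householder
```

Raising `kappa` should make the first norm grow while the second stays near machine precision; the exact values depend on the random draw.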


I think both algorithms run with 2mn^2 flops?

You’re right, pretty much. Trefethen and Bau give 2mn^2 - 2/3 \, n^3 for Householder and 2mn^2 for Gram-Schmidt, or 4/3\, m^3 versus 2m^3 for m=n. It’s LU that’s better than Householder by a factor of two: 2/3\,m^3 for m=n.

TBH I didn’t look closely enough to see if those are just the factorizations or the total cost of an Ax=b solve.
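To put numbers on the counts quoted above, here is a small arithmetic sketch. The function names are made up for illustration, and the formulas are just the leading-order flop counts quoted from Trefethen and Bau.

```julia
# Leading-order flop counts as quoted above (function names are hypothetical).
householder_flops(m, n)  = 2m * n^2 - 2n^3 / 3
gramschmidt_flops(m, n)  = 2m * n^2
lu_flops(m)              = 2m^3 / 3

m, n = 2_000_000, 400

# Tall and narrow: the n^3 term is negligible, so the two counts nearly agree.
@show householder_flops(m, n) / gramschmidt_flops(m, n)

# Square case: Householder needs 2/3 of the Gram-Schmidt count,
# and LU needs half of the Householder count.
@show householder_flops(n, n) / gramschmidt_flops(n, n)
@show lu_flops(n) / householder_flops(n, n)
```

For the 2,000,000\times 400 matrix in this thread the ratio is essentially 1, so flop count alone does not explain a timing gap between the two approaches.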

Thanks for looking it up. So that seems to confirm that I must be leaving some performance on the table, wouldn’t you agree?

The main trick in LAPACK is to exploit level-3 BLAS (matrix-matrix multiplication) to apply groups of Householder reflections at once. It starts by doing a routine Householder QR on the first k columns using a sequence of Householder transformations H_k H_{k-1} \cdots H_1 A and then exploits the fact that if you collect the Householder vectors in an m\times k matrix V (for an m\times n matrix A), their product has a representation of the form

H_k \cdots H_1 = I - V T V^T

where T is triangular. LAPACK has a special routine dlarft for computing T from a set of Householder vectors. You can then apply the Householders to the last n-k columns in A[:,k+1:n] using

H_k \cdots H_1 A[:, k+1:n] = A[:, k+1:n] - V (T (V^T A[:,k+1:n]))

which has products that can be done efficiently using BLAS gemm, which is highly optimized, threaded, and makes vastly better use of cache than algorithms that depend heavily on level 1 and 2 BLAS. You then repeat on the next set of k columns until you finish. The difference between modified Gram-Schmidt and this sort of blocked Householder QR is going to be smaller on such a tall matrix. You really win big when there are more columns to apply the blocked Householder to. But it’s still probably playing a big part in the difference on your example.
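The representation I - V T V^T can be checked numerically with a short sketch. This is not LAPACK's routine; `wy_T` is a hypothetical helper that uses the standard column-by-column recurrence for T, and the reflectors are built from random vectors rather than from an actual QR. With the ordering H_1 \cdots H_k the factor T comes out upper triangular, and the reversed product H_k \cdots H_1 uses its transpose.

```julia
using LinearAlgebra

# Build the k-by-k upper triangular T such that
#   H_1 * H_2 * ... * H_k == I - V * T * V',
# where H_j = I - tau[j] * v_j * v_j' and v_j = V[:, j].
function wy_T(V, tau)
    k = size(V, 2)
    T = zeros(eltype(V), k, k)
    for j in 1:k
        T[j, j] = tau[j]
        if j > 1
            # New column: -tau_j * T_old * (V_old' * v_j)
            T[1:j-1, j] = -tau[j] * (T[1:j-1, 1:j-1] * (V[:, 1:j-1]' * V[:, j]))
        end
    end
    return T
end

m, k = 50, 4
V   = randn(m, k)
tau = [2 / dot(V[:, j], V[:, j]) for j in 1:k]   # makes each H_j a reflector

H(j) = I - tau[j] * V[:, j] * V[:, j]'
P = reduce(*, [H(j) for j in 1:k])               # H_1 * H_2 * ... * H_k
T = wy_T(V, tau)

@show norm(P  - (I - V * T  * V'))   # forward product
@show norm(P' - (I - V * T' * V'))   # reversed product H_k * ... * H_1
```

Both norms should be at roundoff level. The point of the representation is that applying it to a block of columns costs two matrix-matrix products with V plus a small triangular multiply.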

This sort of blocking can be done for modified Gram-Schmidt. Just orthogonalize the first k columns using the slower implementation to get an orthogonalized set in the columns of Q_1 (stored in A[:, 1:k]) and orthogonalize the last n-k columns against the first k using

(I - Q_1 Q_1^T) A[:, k+1:n] = A[:, k+1:n] - Q_1 (Q_1^T A[:,k+1:n])

using LinearAlgebra.BLAS.gemm!. Repeat until done, and I think you should get performance that’s not too far off of what LAPACK gives if you play around with k.

Of course, numerically you will lose orthogonality roughly in proportion to u\kappa(A) for unit roundoff u if you don’t do any reorthogonalization. Modified Gram-Schmidt isn’t a great algorithm from that point of view. There is a trick to use an orthogonalization of the augmented matrix [A\,\, b] that allows backward stable solution of the associated least squares problem despite loss of orthogonality. (And which doesn’t work for classical Gram-Schmidt.)
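A rough sketch of the augmented-matrix trick mentioned above, assuming the usual formulation: run MGS on [A b], take R from the leading n\times n block of the triangular factor and z from the top of its last column, then solve Rx = z. The names `mgs_qr` and `mgs_lstsq` are hypothetical and the code is unoptimized.

```julia
using LinearAlgebra

# Unoptimized MGS returning both factors, with A ≈ Q * R.
function mgs_qr(A)
    Q = float.(copy(A))
    n = size(Q, 2)
    R = zeros(eltype(Q), n, n)
    for j in 1:n
        R[j, j] = norm(Q[:, j])
        Q[:, j] ./= R[j, j]
        for i in j+1:n
            R[j, i] = dot(Q[:, j], Q[:, i])
            Q[:, i] .-= R[j, i] .* Q[:, j]
        end
    end
    return Q, R
end

# Least squares via MGS applied to the augmented matrix [A b].
# The computed Q is never multiplied against b explicitly.
function mgs_lstsq(A, b)
    n = size(A, 2)
    _, Raug = mgs_qr([A b])
    R = UpperTriangular(Raug[1:n, 1:n])
    z = Raug[1:n, n+1]
    return R \ z
end

A = randn(1000, 5)
b = randn(1000)
x = mgs_lstsq(A, b)
@show norm(x - A \ b)   # compare against the built-in least squares solve
```

The example matrix here is well conditioned, so it only checks that the formulation is right; the stability benefit would show up on an ill-conditioned A.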

Edit: It is interesting that the gap in performance is smaller for the 400-column case than for the 200-column case. That’s the opposite of what I would expect, especially if gemm is using threads as well. Is the LAPACK QR running with a single thread? The results look more like that to me. But I haven’t tried running it yet… LinearAlgebra.BLAS.get_num_threads() might be worth checking.
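A quick way to check both thread counts, as a sketch. Julia's own threads (set with `julia -t N`) and the BLAS thread pool are configured separately, so a script started with several Julia threads can still be running BLAS on a different number of threads. The value 4 below is an arbitrary example.

```julia
using LinearAlgebra

@show Threads.nthreads()        # Julia threads, used by Threads.@threads
@show BLAS.get_num_threads()    # threads used inside BLAS/LAPACK calls

# Optionally set the BLAS thread count explicitly (4 is an arbitrary choice).
BLAS.set_num_threads(4)
@show BLAS.get_num_threads()
```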


I think something weird is going on. On my old and slow computer with 4 threads, your code takes 246 seconds on a 2,000,000\times 400 matrix. For qr! I get

julia> BLAS.get_num_threads()
4

julia> size(A)
(2000000, 400)

julia> @time F = qr!(A);
 23.368239 seconds (96.30 k allocations: 5.317 MiB, 0.15% compilation time)

julia> @time F.Q * Matrix(I,2_000_000,400);
 27.737718 seconds (565.43 k allocations: 6.735 GiB, 0.38% gc time, 1.00% compilation time)

That’s really about what I would expect. A decent implementation using optimized level 3 BLAS (matrix-matrix multiplication) is usually a good 10 times faster than anything you can do with level 1 and 2 operations (vector-vector operations and matrix-vector multiply).

Of course qr! just computes the Householder vectors and stores them in A. The REPL does take a long time to display the factorization, I think because it is computing Q to display it. I’m guessing that’s why qr! is taking so much time in your example. Is it possible you are benchmarking the display time as well as computation?

However, even adding the computation of Q is not something that should take as long as the times you are getting. That’s what the F.Q computation is doing above in 28 seconds. I do seem to remember something really slow about the specific way Q is computed when it is displayed as part of the factorization in the REPL. I haven’t tracked it down again, but at some point I thought it used a method for printing AbstractArrays that indexes every element of Q, which then forced multiplying a vector by Householders for every single element, or something along those lines. Poking around in Cthulhu, when printing F.Q, I think that’s probably what’s happening, although I haven’t dug as deeply into it as I did before and my memory is a little hazy. In any event, it falls back to methods that are supposed to work for any AbstractArray, which will be a disaster performance-wise with the way F.Q is stored as Householder vectors.

I think if you introduce blocking into the MGS code you can probably get closer to the 50 seconds I’m seeing for qr! with the assembly of Q.

I’m using this code to benchmark:

@show Threads.nthreads()

m = 2000000
n = 400
A = rand(m, n)
B = deepcopy(A)

A .= B
mgsortho!(A)
@show norm(A'*A - LinearAlgebra.I)
A .= B
@time mgsortho!(A)

A .= B
@time begin
    q = qr!(A)
    A .= Matrix(q.Q)
end
@show norm(A'*A - LinearAlgebra.I)

That is because I really need a copy of the Q matrix.

That sounds wonderful. Could you give me a “for example” please?

You might consider orthogonalizing with classical Gram-Schmidt twice. It uses BLAS calls for the important stuff and is faster than modified Gram-Schmidt if you have more than two cores. My implementation of QR with this is not faster than the Householder version, but you do get Q and R explicitly.
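For orientation, here is a minimal sketch of what "classical Gram-Schmidt twice" can look like. It is not the implementation from the package mentioned in this thread; `cgs2!` is a hypothetical name, it overwrites A with Q, and it does not form R. Each column is projected against all previous columns at once (two matrix-vector products per pass), and the pass is simply repeated.

```julia
using LinearAlgebra

# Classical Gram-Schmidt with one reorthogonalization pass ("CGS twice").
# Overwrites the columns of A with an orthonormal basis of its column space.
function cgs2!(A)
    m, n = size(A)
    h = zeros(eltype(A), n)                  # workspace for the projections
    @views for j in 1:n
        q = A[:, j]
        if j > 1
            Qprev = A[:, 1:j-1]
            c = h[1:j-1]
            for _ in 1:2                     # second pass is the reorthogonalization
                mul!(c, Qprev', q)           # c = Qprev' * q
                mul!(q, Qprev, c, -1, 1)     # q = q - Qprev * c
            end
        end
        q ./= norm(q)
    end
    return A
end

A = randn(2000, 30)
cgs2!(A)
@show norm(A' * A - I)
```

The matrix-vector products go through BLAS, which is where the threading comes from in this formulation; a blocked variant would replace them with matrix-matrix products.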

Sounds good. Is there a code I could look at?

Yes, but it ain’t pretty. In the repository for my package SIAMFANL.jl the relevant files are
test/Chapter3/gmres_test.jl and the function you want is qrctk!
that function calls
src//Solvers/LinearSolvers/Orthogonalize!.jl

I use classical twice in my GMRES code and this stuff is there for CI. I’m working on a paper about classical twice but that’ll take a while longer.


I’ll try a gemm! version to see how it works, although perhaps not in the next day, given some other things I need to do. I should have time Tuesday, maybe tomorrow. But I am wondering: why MGS? You can get Q out of what qr! gives you. It will be numerically orthogonal, which the Q from MGS won’t be unless you are fortunate enough to have a well conditioned matrix. MGS makes it easy to avoid allocating another separate tall matrix to store the Q in the non-Householder form. Is allocating an extra 2,000,000\times 400 matrix the concern?

Wonderful!

First it was due to an idea that I wanted to test out about selectively orthogonalizing.
But now I had a code working, and it lent itself to parallelization with threads.
The results were reported above (no clue where the difference between your measurements and mine comes from).
Here are two sets of results on my Surface Pro (my code, then qr!):


$ for n in 1 2 4 8; do julia -t $n src/qrexp.jl; done    
                                                                  
Threads.nthreads() = 1                                                                                                                                                                  
(m, n) = size(A) = (200000, 400)                                                                                                                                                        
 19.648994 seconds (143.44 k allocations: 7.598 MiB, 0.63% compilation time)                                                                                                            
norm(A' * A - LinearAlgebra.I) = 4.35154e-14                                                                                                                                            
  6.351311 seconds (263.68 k allocations: 1.207 GiB, 0.40% gc time, 2.86% compilation time)                                                                                             
norm(A' * A - LinearAlgebra.I) = 5.91604e-14             
                                                                  
Threads.nthreads() = 2                                                                                                                                                                  
(m, n) = size(A) = (200000, 400)                                                                                                                                                        
 12.953728 seconds (146.39 k allocations: 7.866 MiB, 0.98% compilation time)                                                                                                            
norm(A' * A - LinearAlgebra.I) = 4.65118e-14                                                                                                                                            
  6.328120 seconds (263.68 k allocations: 1.207 GiB, 0.43% gc time, 2.87% compilation time)                                                                                             
norm(A' * A - LinearAlgebra.I) = 5.84090e-14             
                                                                  
Threads.nthreads() = 4                                                                                                                                                                  
(m, n) = size(A) = (200000, 400)                                                                                                                                                        
 10.379690 seconds (151.48 k allocations: 8.381 MiB, 1.24% compilation time)                                                                                                            
norm(A' * A - LinearAlgebra.I) = 4.35497e-14                                                                                                                                            
  6.209290 seconds (263.68 k allocations: 1.207 GiB, 0.42% gc time, 2.90% compilation time)                                                                                             
norm(A' * A - LinearAlgebra.I) = 5.82638e-14             
                                                                  
Threads.nthreads() = 8                                                                                                                                                                  
(m, n) = size(A) = (200000, 400)                                                                                                                                                        
  9.634670 seconds (161.45 k allocations: 9.406 MiB, 1.36% compilation time)                                                                                                            
norm(A' * A - LinearAlgebra.I) = 4.38790e-14                                                                                                                                            
  6.354643 seconds (263.68 k allocations: 1.207 GiB, 0.45% gc time, 2.85% compilation time)                                                                                             
norm(A' * A - LinearAlgebra.I) = 5.86100e-14                                                                                                                                            
                                                                                                                                                                                        
pkonl@Hedwig MINGW64 ~/Documents/00WIP/SubSIt.jl (main)                                                                                                                                 
$ for n in 1 2 4 8; do julia -t $n src/qrexp.jl; done    
                                                                  
Threads.nthreads() = 1                                                                                                                                                                  
(m, n) = size(A) = (2000000, 100)                                                                                                                                                       
 18.368969 seconds (141.64 k allocations: 7.401 MiB, 0.66% compilation time)                                                                                                            
norm(A' * A - LinearAlgebra.I) = 6.58181e-14                                                                                                                                            
  8.199471 seconds (263.68 k allocations: 2.993 GiB, 0.34% gc time, 2.22% compilation time)                                                                                             
norm(A' * A - LinearAlgebra.I) = 1.09271e-13             
                                                                  
Threads.nthreads() = 2                                                                                                                                                                  
(m, n) = size(A) = (2000000, 100)                                                                                                                                                       
 10.838415 seconds (142.36 k allocations: 7.468 MiB, 1.27% compilation time)                                                                                                            
norm(A' * A - LinearAlgebra.I) = 6.74323e-14                                                                                                                                            
  8.128605 seconds (263.68 k allocations: 2.993 GiB, 0.32% gc time, 2.60% compilation time)                                                                                             
norm(A' * A - LinearAlgebra.I) = 7.87493e-14             
                                                                  
Threads.nthreads() = 4                                                                                                                                                                  
(m, n) = size(A) = (2000000, 100)                                                                                                                                                       
  9.085712 seconds (143.69 k allocations: 7.599 MiB, 1.51% compilation time)                                                                                                            
norm(A' * A - LinearAlgebra.I) = 6.97238e-14                                                                                                                                            
  7.818083 seconds (263.68 k allocations: 2.993 GiB, 0.31% gc time, 2.23% compilation time)                                                                                             
norm(A' * A - LinearAlgebra.I) = 1.07916e-13             
                                                                  
Threads.nthreads() = 8                                                                                                                                                                  
(m, n) = size(A) = (2000000, 100)                                                                                                                                                       
  8.489505 seconds (146.23 k allocations: 7.856 MiB, 1.68% compilation time)                                                                                                            
norm(A' * A - LinearAlgebra.I) = 7.09861e-14                                                                                                                                            
  7.733180 seconds (263.68 k allocations: 2.993 GiB, 0.28% gc time, 2.27% compilation time)                                                                                             
norm(A' * A - LinearAlgebra.I) = 1.00242e-13                                                                                                                                            
                                                   

I remember now. It was the equivalent of Matrix(q.Q) that I noticed was slow before. If you go with

@time begin
  q = qr!(A)
  X = Matrix(I, size(A)...)
  Q = q.Q*X
end

It will be faster. Obviously that’s not great for allocating matrices. I’m not sure if there are any methods that allow in-place multiplication by q.Q.

You are right, materializing Q takes almost as long as computing the decomposition with qr!.

You need to preallocate all the matrices to get this to work without burdensome allocations. After that, it did not allocate much in my testing. If you are seeing a problem, please send me an example and I will look into it.

Sorry, I got the Discourse thread and the messages mixed up.

lmul!(q.Q, Matrix{eltype(A)}(I, size(A)))

This seems to be slightly faster and to allocate only half as much as Matrix(q.Q). Not sure what the latter is using, and why it is not using this?

OK. Here is a quick-and-dirty LAPACK-style level 3 version that leverages your code along with a short test similar to what you were running. The only change I made to your code was to drop the @show from mgsortho!:

function mgsortho!(A)
    m, n = size(A)
    normalizecol!(A, 1)
    for j in 2:n 
        Base.Threads.@threads for i in j:n
            colsubt!(A, i, j-1, coldot(A, j-1, i))
        end
        normalizecol!(A, j)
    end
    return A
end

function mgsortho3!(A; block_size = 32)
  m, n = size(A)
  num_blocks, rem_block = divrem(n, block_size)
  work = zeros(eltype(A), block_size, n)
  @views for b = 0:(num_blocks - 1)
    c0 = block_size * b + 1
    c1 = block_size * (b + 1)
    A1 = A[:, c0:c1]
    A2 = A[:, (c1 + 1):n]
    work2 = work[:, (c1 + 1):n]
    mgsortho!(A1)
    mul!(work2, A1', A2, one(eltype(A)), zero(eltype(A)))
    mul!(A2, A1, work2, -one(eltype(A)), one(eltype(A)))
  end
  if rem_block != 0
    mgsortho!(@views A[:, (end - rem_block + 1):end])
  end
  return A
end

@show Threads.nthreads()
@show BLAS.get_num_threads()

m = 2000000
n = 400
A = rand(m, n)
B = deepcopy(A)

A .= B
mgsortho3!(A)
@show norm(A'*A - LinearAlgebra.I)
A .= B
@time mgsortho3!(A);

The output on a 6-core machine was

Threads.nthreads() = 6
BLAS.get_num_threads() = 6
norm(A' * A - LinearAlgebra.I) = 4.981915981648076e-13
 20.463087 seconds (14.66 k allocations: 1.419 MiB)

At this point it’s organized as a blocked version of MGS, but it’s probably not fair to call it MGS. In applying A1' you are computing inner products with multiple orthogonalized columns at once which seems a bit more like classical Gram-Schmidt. Maybe it is best interpreted as a blocked hybrid. Presumably you would want to incorporate reorthogonalization if you want it to be reliable. And you might also want R. Doing a real implementation would be a lot more work. But I think this at least illustrates the general approach.
