Computing the inverse of a stack of matrices

Hi everyone,

I finally started with Julia this morning (long overdue 🙂).

I was looking for an efficient way to invert a stack of 3×3 matrices. Coming from a Python background, I’ve found that numpy.linalg.inv is terribly slow at this.

I briefly looked for anything similar here but couldn’t find it (feel free to point me to an earlier post if you know of one), so here goes:

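using BenchmarkTools, LinearAlgebra
A = randn(1000,4,3,3)   # a 1000×4 stack of 3×3 matrices, matrix indices last (NumPy-style)
function compute_inv(A)
    B = zeros(size(A))
    for i in 1:size(A, 1)
        for j in 1:size(A, 2)
            B[i,j,:,:] = inv(A[i,j,:,:])   # A[i,j,:,:] copies the slice into a new 3×3 Matrix
        end
    end
    return B
end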

@benchmark compute_inv(A)
#   --------------
#   minimum time:     3.723 ms (0.00% GC)
#   median time:      4.688 ms (0.00% GC)
#   mean time:        4.779 ms (5.16% GC)
#   maximum time:     10.037 ms (10.81% GC)
#   --------------

I arrived at this after figuring out that inv accepts only square matrices, not higher-dimensional arrays. I was surprised that it is actually much faster than numpy.linalg.inv, though a bit slower than a hard-coded version. (Of course, I may not be doing things efficiently.)

I would be happy to hear how to improve the performance of the code above.

If you have lots of tiny matrices, then StaticArrays is usually a good idea:

@btime compute_inv(A) setup=(A = randn(1000,4,3,3)); # 1.928 ms

function slice_inv(A) # expects matrix indices first
    B = similar(A)
    @inbounds for j in axes(A,4)
        for i in axes(A,3)
            B[:,:,i,j] .= inv(@view A[:,:,i,j]);
        end
    end
    B
end

@btime slice_inv(A) setup=(A = randn(3,3,4,1000)); # 1.260 ms

using StaticArrays

function slice_inv2(A::Array{T,4}) where {T}
    B = reinterpret(SArray{Tuple{3,3},T,2,9}, vec(A))
    C = map(inv, B)
    reshape(reinterpret(T, C), size(A))   # flatten the inverted SMatrix values back into a plain Array
end

@btime slice_inv2(A) setup=(A = randn(3,3,4,1000)); # 43.037 μs

Use an array of SMatrix (from StaticArrays.jl), which is about 25× faster (update: 75× faster if I fix my type declaration) than your compute_inv on my machine:

julia> using StaticArrays, BenchmarkTools

julia> A = rand(SMatrix{3,3,Float64}, 1000,4);

julia> @btime inv.($A);
  143.487 μs (4002 allocations: 593.83 KiB)

For such small matrices, generic routines that work for any size of matrix have a lot of overhead. The advantage of StaticArrays is huge here because it lets you invoke an unrolled, optimized inversion routine specifically for 3×3 matrices.

(I’m still surprised that it is reporting 4002 allocations, however; not sure why it requires a heap allocation for each element.) Update: I should have used SMatrix{3,3,Float64,9} as explained below.

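(Note also that randn(1000,4,3,3) puts the dimensions in the wrong order for locality. Julia arrays are column-major, so the first index varies fastest in memory; you probably want each 3x3 matrix to be contiguous, i.e. matrix indices first as in randn(3,3,1000,4). Storing things as an array of StaticArrays gets you contiguity automatically.)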
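One quick way to see the locality point is to compare strides, i.e. how many elements apart consecutive entries along each dimension sit in memory. The sketch below uses only Base; the variable names are illustrative.

```julia
A = randn(1000, 4, 3, 3)            # layout from the original post: matrix indices last
strides(A)                          # (1, 1000, 4000, 12000): neighbouring entries of one 3×3 matrix are 4000 or 12000 elements apart

Ap = permutedims(A, (3, 4, 1, 2))   # copy into a 3×3×1000×4 layout: matrix indices first
strides(Ap)                         # (1, 3, 9, 9000): each 3×3 matrix occupies 9 consecutive elements

Ap[2, 3, 5, 1] == A[5, 1, 2, 3]     # same entry, new position
```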


It’s also worth reconsidering what you are doing here. If you are coming from Python/NumPy, it may have been ingrained in you that you should break every computation into a sequence of “vectorized” operations, like inverting every 3×3 matrix at once.

In Julia, while there is nothing wrong with a “vectorized” call like inv.(A) above, you might want to re-think the context and whether you want this operation at all. To get the most out of a processor, the cardinal rule is to do as much work as possible with each input before moving on to the next input. (See also the blog post: Why vectorized code is not as fast as it could be.)

Presumably, you don’t just want the inverses for their own sake, but are planning to use them somehow, for example to solve a huge set of 3×3 systems of equations, and then potentially do further processing. For each 3×3 matrix, you might want to do all the processing associated with that matrix (e.g. solve the system of equations, plus any pre/post-processing) before moving on to the next 3×3 matrix.

(But still use StaticArrays for working with 3×3 systems — that will always be a win.)
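A minimal sketch of that fused style, assuming the matrices are stored as a vector of SMatrix{3,3,Float64,9} and the right-hand sides as SVector{3,Float64}. The function name solve_and_accumulate is hypothetical, and norm(x) is only a stand-in for whatever per-matrix work you actually need. Note that it solves each system with \ rather than forming the inverse at all.

```julia
using StaticArrays, LinearAlgebra

# Hypothetical fused kernel: each 3×3 system is solved and its result consumed
# immediately, instead of first materializing an array of inverses.
function solve_and_accumulate(As::AbstractVector{<:SMatrix{3,3}},
                              bs::AbstractVector{<:SVector{3}})
    total = 0.0
    for k in eachindex(As, bs)
        x = As[k] \ bs[k]   # direct 3×3 solve; no explicit inverse is stored
        total += norm(x)    # stand-in for the real per-matrix post-processing
    end
    return total
end

# Random, reasonably well-conditioned test data (3I shifts the diagonal).
As = [(@SMatrix randn(3, 3)) + 3I for _ in 1:4000]
bs = [(@SVector randn(3)) for _ in 1:4000]
solve_and_accumulate(As, bs)
```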


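Because SMatrix{3,3,Float64} is an abstract type: the fourth parameter (the length L) is left unspecified, so each element of the array is stored as a pointer to a separately heap-allocated object.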

julia> SMatrix{3,3,Float64}
SArray{Tuple{3,3},Float64,2,L} where L

julia> A = rand(SMatrix{3,3,Float64}, 1000,4);

julia> @btime inv.($A);
  161.200 μs (4002 allocations: 593.83 KiB)

julia> A = rand(SMatrix{3,3,Float64,9}, 1000,4);

julia> @btime inv.($A);
  48.200 μs (2 allocations: 281.33 KiB)
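You can check this directly. In the sketch below, isconcretetype confirms that leaving out the fourth (length) parameter gives a UnionAll rather than a concrete type, while the fully specified type is concrete and isbits, so an Array can store the matrices inline. This is consistent with the allocation counts above.

```julia
using StaticArrays

isconcretetype(SMatrix{3,3,Float64})     # false: the length parameter L is still free
isconcretetype(SMatrix{3,3,Float64,9})   # true
isbitstype(SMatrix{3,3,Float64,9})       # true: can be stored inline in an Array

# Type parameters are evaluated, so L can be written as 3 * 3 when building the type:
T = SMatrix{3, 3, Float64, 3 * 3}
A = rand(T, 1000, 4)
eltype(A) === SMatrix{3,3,Float64,9}     # true
```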

Ah, thanks! I work more with SVector than SMatrix and forgot that SMatrix has an additional (somewhat redundant?) type parameter.


I think the reason it has the extra parameter is that a lot of algorithm cutoffs are based on the number of elements in the matrix, so being able to get that directly from the type domain probably saves a multiplication in a decent number of places.

Thanks. I’ll explore StaticArrays. I have marked Steven’s answer above as the solution, but this could just as well be one 🙂

Thanks @stevengj for a superbly informative reply. I wasn’t aware of StaticArrays.jl. This is indeed very fast.

Absolutely 🙂!!

This ultimately feeds a finite element calculation in elasticity, where I have a deformation gradient tensor at each quadrature point in my domain, hence the shape (3,3,1000,4). Since the subsequent calculations depend on the inverses of these 3×3 matrices, I was looking to speed up the inversion.

Absolutely. Thanks a lot for such an enlightening response.


I see, thanks! I wrote the MWE on the fly following the numpy.linalg.inv convention (stack dimensions first, matrix dimensions last), but the original shape of my arrays is indeed (3,3,1000,4). Using StaticArrays is the way to go 🙂
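If the data already lives in a plain 3×3×1000×4 Array, here is a sketch of one way to treat it as an array of SMatrix without copying, using the same reinterpret trick as slice_inv2 in the thread. The names F, Fs and Finv are illustrative.

```julia
using StaticArrays, LinearAlgebra

F = randn(3, 3, 1000, 4)   # e.g. deformation gradients, matrix indices first

# Reinterpret the same memory as 4000 SMatrix{3,3,Float64,9} values (no copy),
# then restore the 1000×4 quadrature-point layout.
Fs = reshape(reinterpret(SMatrix{3,3,Float64,9}, vec(F)), 1000, 4)

Finv = inv.(Fs)            # 1000×4 Matrix{SMatrix{3,3,Float64,9}}
```

Because Fs is only a view of F, any later in-place update of F is reflected in Fs.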

Thanks a lot for posting this. As it stands, on my machine,

julia> A = rand(SMatrix{3,3,Float64,9}, 1000,4);
julia> @btime inv.($A)
  33.401 μs (2 allocations: 281.33 KiB)

vs the fastest (hard-coded) version using NumPy

from numpy import zeros_like, random

def vdet(A):
    # determinant of every 3×3 matrix at once, via cofactor expansion along the first row
    return (A[0, 0] * (A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1]) -
            A[0, 1] * (A[2, 2] * A[1, 0] - A[2, 0] * A[1, 2]) +
            A[0, 2] * (A[1, 0] * A[2, 1] - A[2, 0] * A[1, 1]))

def hdinv(A):
    invA = zeros_like(A)
    detA = vdet(A)

    invA[0, 0] = (-A[1, 2] * A[2, 1] +
                  A[1, 1] * A[2, 2]) / detA
    invA[1, 0] = (A[1, 2] * A[2, 0] -
                  A[1, 0] * A[2, 2]) / detA
    invA[2, 0] = (-A[1, 1] * A[2, 0] +
                  A[1, 0] * A[2, 1]) / detA
    invA[0, 1] = (A[0, 2] * A[2, 1] -
                  A[0, 1] * A[2, 2]) / detA
    invA[1, 1] = (-A[0, 2] * A[2, 0] +
                  A[0, 0] * A[2, 2]) / detA
    invA[2, 1] = (A[0, 1] * A[2, 0] -
                  A[0, 0] * A[2, 1]) / detA
    invA[0, 2] = (-A[0, 2] * A[1, 1] +
                  A[0, 1] * A[1, 2]) / detA
    invA[1, 2] = (A[0, 2] * A[1, 0] -
                  A[0, 0] * A[1, 2]) / detA
    invA[2, 2] = (-A[0, 1] * A[1, 0] +
                  A[0, 0] * A[1, 1]) / detA
    return invA


F = random.random((3, 3, 1000, 4))
%timeit hdinv(F)
  179 µs ± 4.49 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

This is superb!!


A single 3×3 matrix inversion doesn’t SIMD well, so it is faster to vectorize across the stack, letting each SIMD instruction work on several matrices at once:

using StaticArrays, BenchmarkTools

Asmat = rand(SMatrix{3,3,Float64,9}, 1000,4);
Bsmat = similar(Asmat);
Aarray = permutedims(reshape(reinterpret(Float64, Asmat), (3,3,1000,4)), (3,4,1,2));
Barray = similar(Aarray);

Loop doing the inversions:

function many3x3inverts!(Y, X)
    @assert size(Y) === size(X)
    @assert size(Y,3) == size(Y,4) === 3
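    # The inner loop runs over the stack index i, which is contiguous in memory,
    # so each SIMD lane works on a different 3×3 matrix.
    @inbounds for j ∈ axes(Y,2); @simd ivdep for i ∈ axes(Y,1)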
        
        X₁₁ = X[i,j,1,1]
        X₂₁ = X[i,j,2,1]
        X₃₁ = X[i,j,3,1]
        X₁₂ = X[i,j,1,2]
        X₂₂ = X[i,j,2,2]
        X₃₂ = X[i,j,3,2]
        X₁₃ = X[i,j,1,3]
        X₂₃ = X[i,j,2,3]
        X₃₃ = X[i,j,3,3]

        # Adjugate of X (transposed cofactor matrix); inv(X) = adj(X) / det(X)
        Y₁₁ = X₂₂*X₃₃ - X₂₃*X₃₂
        Y₂₁ = X₂₃*X₃₁ - X₂₁*X₃₃
        Y₃₁ = X₂₁*X₃₂ - X₂₂*X₃₁
        
        Y₁₂ = X₁₃*X₃₂ - X₁₂*X₃₃
        Y₂₂ = X₁₁*X₃₃ - X₁₃*X₃₁
        Y₃₂ = X₁₂*X₃₁ - X₁₁*X₃₂
        
        Y₁₃ = X₁₂*X₂₃ - X₁₃*X₂₂
        Y₂₃ = X₁₃*X₂₁ - X₁₁*X₂₃
        Y₃₃ = X₁₁*X₂₂ - X₁₂*X₂₁
        
        # 1/det(X), using the first-row cofactor expansion
        d = 1 / ( X₁₁*Y₁₁ + X₁₂*Y₂₁ + X₁₃*Y₃₁ )
        
        Y[i,j,1,1] = Y₁₁ * d
        Y[i,j,2,1] = Y₂₁ * d
        Y[i,j,3,1] = Y₃₁ * d
        Y[i,j,1,2] = Y₁₂ * d
        Y[i,j,2,2] = Y₂₂ * d
        Y[i,j,3,2] = Y₃₂ * d
        Y[i,j,1,3] = Y₁₃ * d
        Y[i,j,2,3] = Y₂₃ * d
        Y[i,j,3,3] = Y₃₃ * d

    end; end
end

Results:

julia> @btime $Bsmat .= inv.($Asmat);
  44.112 μs (0 allocations: 0 bytes)

julia> @btime many3x3inverts!($Barray, $Aarray)
  5.946 μs (0 allocations: 0 bytes)

julia> permutedims(reshape(reinterpret(Float64, Bsmat), (3,3,1000,4)), (3,4,1,2)) ≈ Barray
 true

EDIT: Also, it’d be better in general to flatten the first two dims to 4000x3x3 instead of 1000x4x3x3.


Thanks for posting this. Coincidentally, I was trying to compare the previous versions with a Numba-accelerated Python version (Numba also compiles through LLVM, so I wanted to see whether that made any difference), but with this version Julia is about 6× faster 🙂

from numpy import zeros_like
from numba import jit

@jit(nopython=True, cache=True, nogil=True)
def vdet(F):
    J = zeros_like(F[0,0])
    for a in range(J.shape[0]):
        for b in range(J.shape[1]):
            J[a,b] += F[0, 0, a, b] * (F[1, 1, a, b] * F[2, 2, a, b] -
                                        F[1, 2, a, b] * F[2, 1, a, b]) -\
                    F[0, 1, a, b] * (F[1, 0, a, b] * F[2, 2, a, b] -
                                        F[1, 2, a, b] * F[2, 0, a, b]) +\
                    F[0, 2, a, b] * (F[1, 0,a ,b] * F[2, 1, a, b] -
                                        F[1, 1, a, b] * F[2, 0, a, b])
    return J


@jit(nopython=True, cache=True, nogil=True)
def vinv(F):
    J = vdet(F)
    Finv = zeros_like(F)
    for a in range(J.shape[0]):
        for b in range(J.shape[1]):
            Finv[0, 0, a, b] += (-F[1, 2, a, b] * F[2, 1, a, b] +
                                F[1, 1, a, b] * F[2, 2, a, b]) / J[a, b]
            Finv[1, 0, a, b] += (F[1, 2, a, b] * F[2, 0, a, b] -
                                F[1, 0, a, b] * F[2, 2, a, b]) / J[a, b]
            Finv[2, 0, a, b] += (-F[1, 1, a, b] * F[2, 0, a, b] +
                                F[1, 0, a, b] * F[2, 1, a, b]) / J[a, b]
            Finv[0, 1, a, b] += (F[0, 2, a, b] * F[2, 1, a, b] -
                                F[0, 1, a, b] * F[2, 2, a, b]) / J[a, b]
            Finv[1, 1, a, b] += (-F[0, 2, a, b] * F[2, 0, a, b] +
                                F[0, 0, a, b] * F[2, 2, a, b]) / J[a, b]
            Finv[2, 1, a, b] += (F[0, 1, a, b] * F[2, 0, a, b] -
                                F[0, 0, a, b] * F[2, 1, a, b]) / J[a, b]
            Finv[0, 2, a, b] += (-F[0, 2, a, b] * F[1, 1, a, b] +
                                F[0, 1, a, b] * F[1, 2, a, b]) / J[a, b]
            Finv[1, 2, a, b] += (F[0, 2, a, b] * F[1, 0, a, b] -
                                F[0, 0, a, b] * F[1, 2, a, b]) / J[a, b]
            Finv[2, 2, a, b] += (-F[0, 1, a, b] * F[1, 0, a, b] +
                                F[0, 0, a, b] * F[1, 1, a, b]) / J[a, b]
    return Finv


%timeit vinv(F)
  69.3 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

vs your code

julia> @btime many3x3inverts!($Barray, $Aarray)
  10.399 μs (0 allocations: 0 bytes)

It’s remarkably fast. Thanks for such an enlightening discussion.


I see, thanks for pointing this out! But in general, if I wanted the array to be shaped 1000×4×3×3 instead of 4000×3×3 for further calculations downstream, would flattening it and reshaping it back be worth it performance-wise?

Adding @fastmath or (better, because @fastmath will likely be deprecated) using MuladdMacro and adding @muladd should get you under 10 microseconds by letting it use fma instructions (fused multiply-add).
Doing this, I got 5.4 microseconds with @fastmath and 5.5 with @muladd.
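@muladd (from MuladdMacro.jl) performs this kind of rewrite automatically. To show what it buys without adding a dependency, here is a hand-written sketch using Base’s muladd on the determinant computed in many3x3inverts!. The function names det3 and det3_fma are hypothetical.

```julia
using LinearAlgebra

# Plain version: first-row cofactor expansion with separate multiplies and adds.
det3(X) = X[1,1]*(X[2,2]*X[3,3] - X[2,3]*X[3,2]) +
          X[1,2]*(X[2,3]*X[3,1] - X[2,1]*X[3,3]) +
          X[1,3]*(X[2,1]*X[3,2] - X[2,2]*X[3,1])

# Same computation with muladd(a, b, c) = a*b + c, which the compiler may lower to a
# single fused multiply-add instruction on CPUs that have one.
function det3_fma(X)
    Y₁₁ = muladd(X[2,2], X[3,3], -X[2,3]*X[3,2])
    Y₂₁ = muladd(X[2,3], X[3,1], -X[2,1]*X[3,3])
    Y₃₁ = muladd(X[2,1], X[3,2], -X[2,2]*X[3,1])
    return muladd(X[1,3], Y₃₁, muladd(X[1,2], Y₂₁, X[1,1]*Y₁₁))
end

X = [2.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 4.0]
det3(X), det3_fma(X), det(X)   # all ≈ 18.0
```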

If it’s always going to be 1000x4, then no.
This version is fastest because it uses SIMD (Single Instruction, Multiple Data): most CPUs can operate on short vectors of numbers at once. With AVX512, for example, multiplying two sets of 8 numbers is nearly as fast as multiplying just 2 numbers. This is why I saw a nearly 8x speed increase.
You can use @code_llvm to see this; each operation is acting on 8 numbers:

  %wide.load53 = load <8 x double>, <8 x double>* %105, align 8
  %106 = fmul <8 x double> %wide.load49, %wide.load53
  %107 = fmul contract <8 x double> %wide.load50, %wide.load52
  %108 = fsub contract <8 x double> %106, %107
  %109 = fmul <8 x double> %wide.load47, %wide.load52
  %110 = fmul contract <8 x double> %wide.load46, %wide.load53
  %111 = fsub contract <8 x double> %109, %110
  %112 = fmul <8 x double> %wide.load46, %wide.load50
  %113 = fmul contract <8 x double> %wide.load47, %wide.load49
  %114 = fsub contract <8 x double> %112, %113
  %115 = fmul <8 x double> %wide.load50, %wide.load51
  %116 = fmul contract <8 x double> %wide.load48, %wide.load53
  %117 = fsub contract <8 x double> %115, %116
  %118 = fmul <8 x double> %wide.load, %wide.load53
  %119 = fmul contract <8 x double> %wide.load47, %wide.load51
  %120 = fsub contract <8 x double> %118, %119
  %121 = fmul <8 x double> %wide.load47, %wide.load48
  %122 = fmul contract <8 x double> %wide.load, %wide.load50
  %123 = fsub contract <8 x double> %121, %122
  %124 = fmul <8 x double> %wide.load48, %wide.load52
  %125 = fmul contract <8 x double> %wide.load49, %wide.load51
  %126 = fsub contract <8 x double> %124, %125
  %127 = fmul <8 x double> %wide.load46, %wide.load51
  %128 = fmul contract <8 x double> %wide.load, %wide.load52
  %129 = fsub contract <8 x double> %127, %128
  %130 = fmul <8 x double> %wide.load, %wide.load49
  %131 = fmul contract <8 x double> %wide.load46, %wide.load48
  %132 = fsub contract <8 x double> %130, %131
  %133 = fmul <8 x double> %wide.load, %108
  %134 = fmul contract <8 x double> %wide.load48, %111
  %135 = fadd contract <8 x double> %134, %133
  %136 = fmul contract <8 x double> %114, %wide.load51
  %137 = fadd contract <8 x double> %136, %135
  %138 = fdiv <8 x double> <double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00>, %137
  %139 = fmul <8 x double> %108, %138

This means that on my computer it’s fastest when the length of the SIMD loop (here the first dimension) is a multiple of 8.
For computers without AVX512 but with AVX, multiples of 4 will suffice.

Any remaining iterations are done one at a time.
For example, if you had AVX and your arrays were 15 x 4 x 3 x 3, it’d spend about as long on the last three iterations of each inner loop as on the first 12.
By flattening it into 60 x 3 x 3, you don’t have any remainder at all anymore, and the entire thing can be done quickly.
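The remainder arithmetic can be made concrete with a deliberately crude cost model (an assumption for illustration, not a measurement): count one step per full SIMD vector and one step per leftover scalar iteration, with W the vector width (4 doubles for AVX, 8 for AVX512). The helper simd_steps is hypothetical.

```julia
# Crude cost model: full vectors plus leftover scalar iterations.
simd_steps(n, W) = div(n, W) + rem(n, W)

4 * simd_steps(15, 4)   # 15×4 stack, inner loop of 15 run 4 times: 4 * (3 + 3) = 24
simd_steps(60, 4)       # flattened to 60: 15 full vectors, no remainder = 15
simd_steps(1000, 8)     # 1000 is a multiple of 8: 125, nothing left over
```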

This can be a big deal in simpler code, where the compiler will often unroll the loop 4× as well: instead of fast batches of 4 followed by 1 at a time, it’ll do fast batches of 16 followed by 1 at a time.

But if you have 1000 along the first dimension, then iszero(size(A,1) % 4) (and iszero(size(A,1) % 8)), so it won’t really get any benefit from flattening.

FWIW, reshaping is fairly fast, since reshape doesn’t copy the data:

julia> @benchmark reshape(reshape($Barray, (4000,3,3)), (1000,4,3,3))
 BenchmarkTools.Trial:
  memory estimate:  160 bytes
  allocs estimate:  2
  --------------
  minimum time:     52.799 ns (0.00% GC)
  median time:      54.727 ns (0.00% GC)
  mean time:        61.040 ns (7.78% GC)
  maximum time:     2.159 μs (95.25% GC)
  --------------
  samples:          10000
  evals/sample:     985
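A small Base-only sketch of why this is cheap: reshape on an Array returns an array sharing the same memory, and with column-major storage, index (i, j) of a 1000×4 stack maps to i + 1000*(j - 1) in the flattened 4000-long dimension. Variable names are illustrative.

```julia
A  = randn(1000, 4, 3, 3)
Af = reshape(A, 4000, 3, 3)                     # no copy: same memory, new dimensions

Af[5 + 1000 * (3 - 1), 2, 3] == A[5, 3, 2, 3]   # column-major index mapping

Af[1, 1, 1] = 42.0                              # writing through the flattened array...
A[1, 1, 1, 1]                                   # ...is visible in the original: 42.0
```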

This is really nice and informative. Thanks again for such a detailed explanation.


An alternative to StaticArrays is my Grassmann.jl package, which can do multilinear algebra without ever defining higher-order array interfaces. The newly released v0.6 deprecated StaticArrays.

using Grassmann, StaticArrays
A = rand(SMatrix{3,3,Float64},1000,4);
B = Chain{ℝ3,1,Chain{ℝ3,1}}.(A);

The tests show that Chain{V,1,Chain{V,1}} is a bit slower at inv than SMatrix, yet actually many times faster at computing determinants, as the det timings show:

julia> @btime inv.($A);
  128.939 μs (4002 allocations: 593.83 KiB)

julia> @btime inv.($B);
  181.841 μs (4002 allocations: 906.33 KiB)

julia> @btime det.($A);
  92.142 μs (4002 allocations: 93.83 KiB)

julia> @btime det.($B);
  14.172 μs (2 allocations: 31.33 KiB)

As you can see, Grassmann algebra provides a foundation (entirely independent of StaticArrays) for doing multilinear algebra. It is also faster for some other methods, such as the \ linear solve.

Instead of using a traditional StaticArray interface, the Grassmann foundation is used. This is represented with a different memory layout, which can be used in interesting ways.

I’m sure there are many more optimizations possible on Grassmann in the future.

As mentioned above, you are not using the concrete type:

julia> A = rand(SMatrix{3,3,Float64},1000,4);

julia> @btime inv.($A);
  108.500 μs (4002 allocations: 593.83 KiB)

julia> A = rand(SMatrix{3,3,Float64,9},1000,4);

julia> @btime inv.($A);
  27.700 μs (2 allocations: 281.33 KiB)


This may be a nice incidental advantage to having L (… though constant propagation could make the point of such benefits moot?) but I think the fundamental reason is that an SMatrix{N,M,...} is backed by an NTuple{N*M, ...}: since you can’t currently do arithmetic with type parameters/TypeVars (#18466) in struct constructors, this forces the existence of the otherwise redundant L.
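A minimal sketch of the same design constraint, using a hypothetical MyMat type (not StaticArrays’ actual code): the tuple length has to be its own parameter L, because a field type such as NTuple{N*M,T} is not allowed. An inner constructor can check that L == N*M, and an outer constructor can infer L so callers never write it.

```julia
# Hypothetical mini version of an NTuple-backed static matrix.
struct MyMat{N,M,T,L}
    data::NTuple{L,T}   # NTuple{N*M,T} would be rejected: no arithmetic on type parameters here
    function MyMat{N,M,T,L}(data::NTuple{L,T}) where {N,M,T,L}
        L == N * M || throw(ArgumentError("L must equal N*M"))
        return new{N,M,T,L}(data)
    end
end

# Convenience constructor: T and L are inferred from the tuple.
MyMat{N,M}(data::NTuple{L,T}) where {N,M,T,L} = MyMat{N,M,T,L}(data)

m = MyMat{2,2}((1.0, 2.0, 3.0, 4.0))   # MyMat{2, 2, Float64, 4}
```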


(One could define a macro @SMatrix{n,m,T} that constructs SMatrix{n,m,T,n*m}, just to make it a bit easier to define a concrete type?)
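A sketch of that idea, with one caveat: StaticArrays already exports @SMatrix for constructing matrix literals, so the hypothetical macro below is named @smat to avoid a clash. It only works when n and m are integer literals, because n*m is computed at macro-expansion time; a plain function covers runtime sizes.

```julia
using StaticArrays

# Hypothetical helper (not part of StaticArrays): @smat(n, m, T) expands to SMatrix{n, m, T, n*m}.
macro smat(n::Integer, m::Integer, T)
    return esc(:(SMatrix{$n, $m, $T, $(n * m)}))
end

# A plain function works too, including when n and m are runtime values.
smat_type(n, m, T) = SMatrix{n, m, T, n * m}

A = rand(@smat(3, 3, Float64), 1000, 4)   # concrete eltype SMatrix{3,3,Float64,9}
```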
