Faster way to perform a rank-1 update to a matrix inverse

I want to update a matrix inverse using the Sherman-Morrison formula

$$(A + uv^\top)^{-1} = A^{-1} - \frac{A^{-1} u v^\top A^{-1}}{1 + v^\top A^{-1} u}$$

for a large, dense matrix inside a loop (a coordinate descent algorithm). Here is a minimal working example of the best code I have found so far:

#mwe rank-1 update

using LinearAlgebra, BenchmarkTools

A  = [[1,2] [3,4]]
A_inv = inv(A)
# want to update column 2 from [3,4] to [4,6], i.e. add u*v' with v = e_2
v = [0.0, 1.0]
u = [1.0, 2.0]
w = Vector{Float64}(undef, 2) 
z = Vector{Float64}(undef, 2) 

function update_sm!(Λ::AbstractMatrix{<:Real}, u::AbstractVector{<:Real}, v::AbstractVector{<:Real},
    w::AbstractVector{<:Real}, z::AbstractVector{<:Real})

    # performant way of doing Sherman-Morrison update
    # Λ = Λ - (Λ*u)*(v'*Λ) / (1 + v'*Λ*u)
    n = size(Λ, 1)
    @assert size(Λ, 2) == n
    @assert length(u) == n && length(v) == n
    @assert length(w) == n && length(z) == n

    # Step 1: compute w = Λ * u
    mul!(w, Λ, u) # this is faster than the previous version

    # Step 2: compute the denominator 1 + v'*Λ*u, reusing w from step 1
    denom = 1.0 + dot(v, w) 

    # Step 3: compute z = Λ' * v (i.e., z[j] = sum_i Λ[i,j] * v[i])
    mul!(z, Λ', v) 

    # Step 4: Λ -= (1/denom) * w * z', with j in the outer loop to match Julia's column-major layout
    scale = 1.0 / denom
    @inbounds for j in 1:n
        fac = z[j] * scale
        for i in 1:n
            Λ[i, j] -= fac * w[i]
        end
    end

    return nothing
end

@btime update_sm!($A_inv, $u, $v, $w, $z)

Is there any way to speed this up in Julia? The BenchmarkTools output is currently:

55.695 ns (0 allocations: 0 bytes)
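When tuning an in-place version like this, it helps to keep a naive, allocating implementation of the formula as a reference and compare it with a freshly computed inverse. A minimal sketch on the 2×2 example (`update_sm_naive` and `Λ_new` are hypothetical names, not from the thread):

```julia
using LinearAlgebra

# Naive Sherman-Morrison: allocates temporaries, but mirrors the formula directly.
# Given Λ = inv(A), returns inv(A + u*v').
update_sm_naive(Λ, u, v) = Λ - (Λ * u) * (v' * Λ) / (1 + v' * Λ * u)

A = [1.0 3.0; 2.0 4.0]     # same matrix as [[1,2] [3,4]]
u = [1.0, 2.0]
v = [0.0, 1.0]             # v = e_2, so u*v' adds u to column 2
Λ_new = update_sm_naive(inv(A), u, v)
Λ_new ≈ inv(A + u * v')    # true; here A + u*v' == [1.0 4.0; 2.0 6.0]
```

The naive version allocates several temporaries per call, so it is only meant for testing the fast versions, not for the inner loop.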

Are you actually working with 2x2 matrices?

  • If so, you should definitely use StaticArrays.jl
  • If not, you should benchmark on your realistic problem size; otherwise the results won’t be very helpful to you
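A minimal sketch of the StaticArrays.jl route for the 2×2 case, assuming the matrix size is fixed at compile time. Since `SMatrix` is immutable, the (hypothetical) `update_sm_static` returns the updated inverse instead of mutating it:

```julia
using LinearAlgebra, StaticArrays

# Sherman-Morrison for small, fixed-size matrices. SMatrix/SVector are immutable
# and stack-allocated, so this returns a new matrix rather than updating in place.
function update_sm_static(Λ::SMatrix{N,N,T}, u::SVector{N,T}, v::SVector{N,T}) where {N,T<:Real}
    w = Λ * u                                   # Λ*u
    z = Λ' * v                                  # (v'*Λ)'
    return Λ - (w * z') / (one(T) + dot(v, w))  # outer product scaled by 1/(1 + v'*Λ*u)
end

A = @SMatrix [1.0 3.0; 2.0 4.0]
u = @SVector [1.0, 2.0]
v = @SVector [0.0, 1.0]
Λ = update_sm_static(inv(A), u, v)
Λ ≈ inv(A + u * v')    # true
```

StaticArrays generally only pays off for small arrays (its documentation gives a rough rule of thumb of about 100 elements), so this route does not carry over to large n.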
# (A + u*v')^-1 = A^-1 - (A^-1*u*v'*A^-1) / (1 + v'*A^-1*u)
function update_sm2!(Λ::AbstractMatrix{<:Real}, u::AbstractVector{<:Real}, v::AbstractVector{<:Real},
    w::AbstractVector{<:Real}, z::AbstractVector{<:Real})

    # performant way of doing Sherman-Morrison update
    # Λ = Λ - (Λ*u)*(v'*Λ) / (1 + v'*Λ*u)
    n = size(Λ, 1)
    @assert size(Λ, 2) == n
    @assert length(u) == n && length(v) == n
    @assert length(w) == n && length(z) == n

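    mul!(w, Λ, u)                    # w = Λ*u
    mul!(z, Λ', v)                   # z = Λ'*v, i.e. z' = v'*Λ
    denom = -1 / (1.0 + dot(v, w))   # scale factor -1 / (1 + v'*Λ*u)
    BLAS.ger!(denom, w, z, Λ)        # in place: Λ = denom*w*z' + Λ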

    return nothing
end
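`BLAS.ger!(α, x, y, A)` overwrites `A` with `α*x*y' + A` in place, which is why `update_sm2!` passes the negative factor `-1 / (1 + v'w)`. A small check on toy values (x, y and A here are illustrative only):

```julia
using LinearAlgebra

# BLAS.ger!(α, x, y, A) performs the in-place rank-1 update A = α*x*y' + A.
# x, y and A must share a BLAS element type such as Float64.
A = [1.0 2.0; 3.0 4.0]
x = [1.0, -1.0]
y = [0.5, 2.0]
expected = A + 2.0 * x * y'   # allocating reference
BLAS.ger!(2.0, x, y, A)       # mutates A, no allocation
A ≈ expected                  # true
```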

Checking that both versions give the same answer:

# data
n = 2000
A  = randn(n, n)
A_inv = inv(A)
v = randn(n)
u = randn(n)
w = Vector{Float64}(undef, n) 
z = Vector{Float64}(undef, n) 

# compare answers
A_inv, A_inv2 = inv(A), inv(A)
update_sm!(A_inv, u,v,w,z) 
update_sm2!(A_inv2, u,v,w,z) 
all(A_inv .≈ A_inv2) # true

My version is equally fast at n = 100 and about 3x faster at n = 2000:

julia> @btime update_sm!($A_inv, $u,$v,$w,$z) # 41.2 ns
  3.228 ms (0 allocations: 0 bytes)

julia> @btime update_sm2!($A_inv2, $u,$v,$w,$z) # 59.2 ns
  1.038 ms (0 allocations: 0 bytes)
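If you prefer to stay in pure Julia, the rank-1 step can also be written as a fused, in-place broadcast that allocates nothing. This is a sketch (`update_sm_bcast!` is a hypothetical name); the broadcast runs on one thread, whereas `BLAS.ger!` may use several BLAS threads, so whether it keeps up would have to be benchmarked at your problem size:

```julia
using LinearAlgebra

# Sherman-Morrison with a fused broadcast for the rank-1 step (no temporaries).
function update_sm_bcast!(Λ::AbstractMatrix{T}, u::AbstractVector, v::AbstractVector,
                          w::AbstractVector, z::AbstractVector) where {T<:Real}
    mul!(w, Λ, u)                        # w = Λ*u
    mul!(z, Λ', v)                       # z = Λ'*v, i.e. z' = v'*Λ
    α = -one(T) / (one(T) + dot(v, w))   # -1 / (1 + v'*Λ*u)
    Λ .+= α .* w .* z'                   # Λ += α*w*z', fused into one loop over Λ
    return nothing
end

n = 100
A = randn(n, n) + n * I                  # shifted so it is well conditioned (assumption)
u, v = randn(n), randn(n)
Λ = inv(A)
w, z = similar(u), similar(u)
update_sm_bcast!(Λ, u, v, w, z)
Λ ≈ inv(A + u * v')                      # true
```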

I am working on a problem of size around 5000 × 5000, and the update is repeated 5000 times within the loop. Here is an updated MWE based on biona001’s answer (thanks!):

using LinearAlgebra, BenchmarkTools, Random, Distributions
const n = 5000
A  = rand(Uniform(), n, n)
A_inv = inv(A)
A_inv_2 = copy(A_inv)
v = zeros(Float64, n)
u = ones(Float64, n)
w = Vector{Float64}(undef, n) 
z = Vector{Float64}(undef, n) 

v[1] = 1 # in the real loop, v cycles through the unit vectors e_i

function update_sm!(Λ::AbstractMatrix{<:Real}, u::AbstractVector{<:Real}, v::AbstractVector{<:Real},
    w::AbstractVector{<:Real}, z::AbstractVector{<:Real})

    # performant way of doing Sherman-Morrison update
    # Λ = Λ - (Λ*u)*(v'*Λ) / (1 + v'*Λ*u)
    n = size(Λ, 1)
    @assert size(Λ, 2) == n
    @assert length(u) == n && length(v) == n
    @assert length(w) == n && length(z) == n

    # Step 1: compute w = Λ * u
    mul!(w, Λ, u) # this is faster than the previous version

    # Step 2: compute the denominator 1 + v'*Λ*u, reusing w from step 1
    denom = 1.0 + dot(v, w) 

    # Step 3: compute z = Λ' * v (i.e., z[j] = sum_i Λ[i,j] * v[i])
    mul!(z, Λ', v) 

    # Step 4: Λ -= (1/denom) * w * z', with j in the outer loop to match Julia's column-major layout
    scale = 1.0 / denom
    @inbounds for j in 1:n
        fac = z[j] * scale
        for i in 1:n
            Λ[i, j] -= fac * w[i]
        end
    end

    return nothing
end

function update_sm2!(Λ::AbstractMatrix{<:Real}, u::AbstractVector{<:Real}, v::AbstractVector{<:Real},
    w::AbstractVector{<:Real}, z::AbstractVector{<:Real})

    # performant way of doing Sherman-Morrison update
    # Λ = Λ - (Λ*u)*(v'*Λ) / (1 + v'*Λ*u)
    n = size(Λ, 1)
    @assert size(Λ, 2) == n
    @assert length(u) == n && length(v) == n
    @assert length(w) == n && length(z) == n

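    mul!(w, Λ, u)                    # w = Λ*u
    mul!(z, Λ', v)                   # z = Λ'*v, i.e. z' = v'*Λ
    denom = -1 / (1.0 + dot(v, w))   # scale factor -1 / (1 + v'*Λ*u)
    BLAS.ger!(denom, w, z, Λ)        # in place: Λ = denom*w*z' + Λ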

    return nothing
end

@btime update_sm!($A_inv, $u,$v,$w,$z)
 21.846 ms (0 allocations: 0 bytes)

@btime update_sm2!($A_inv_2, $u,$v,$w,$z)
 12.478 ms (0 allocations: 0 bytes)

A_inv ≈ A_inv_2  # only meaningful on fresh copies: @btime mutates its arguments on every evaluation

I find that the second function is about twice as fast. Is there any way to speed this up even more?
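A rough back-of-the-envelope estimate of memory traffic may help decide where further gains can come from. Assuming each step streams all of Λ from main memory (at n = 5000 it is much larger than typical CPU caches), `update_sm2!` makes about four passes over a 200 MB matrix per call. The variable names below are illustrative:

```julia
# Rough memory-traffic estimate for one call of update_sm2! at n = 5000.
# Assumption: every pass reads (or writes) all of Λ from main memory.
n = 5000
bytes_Λ = 8 * n^2                    # Float64 entries: 200_000_000 bytes (200 MB)
passes = 4                           # read for Λ*u, read for Λ'*v, read + write in ger!
traffic = passes * bytes_Λ           # 800 MB moved per update
t = 12.478e-3                        # measured time of update_sm2! above, in seconds
bandwidth_GBs = traffic / t / 1e9    # ≈ 64.1 GB/s effective
total_s = 5000 * t                   # ≈ 62.4 s for 5000 updates
```

If that effective figure is close to your machine's memory bandwidth, the update is memory-bound, and the main lever left is reducing the number of full passes over Λ rather than the arithmetic.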


In your algorithm, v is always a standard unit vector, so I’d avoid multiplying by it: replace the dot product with an index, and the second mul! with a view:

using LinearAlgebra, BenchmarkTools, Random, Distributions

function update_sm2!(Λ::AbstractMatrix{<:Real}, u::AbstractVector{<:Real}, v::AbstractVector{<:Real},
    w::AbstractVector{<:Real}, z::AbstractVector{<:Real})

    # performant way of doing Sherman-Morrison update
    # Λ = Λ - (Λ*u)*(v'*Λ) / (1 + v'*Λ*u)
    n = size(Λ, 1)
    @assert size(Λ, 2) == n
    @assert length(u) == n && length(v) == n
    @assert length(w) == n && length(z) == n

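    mul!(w, Λ, u)                    # w = Λ*u
    mul!(z, Λ', v)                   # z = Λ'*v, i.e. z' = v'*Λ
    denom = -1 / (1.0 + dot(v, w))   # scale factor -1 / (1 + v'*Λ*u)
    BLAS.ger!(denom, w, z, Λ)        # in place: Λ = denom*w*z' + Λ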

    return nothing
end

function update_sm3!(Λ::AbstractMatrix{<:Real}, u::AbstractVector{<:Real}, n_eff::Int,
    w::AbstractVector{<:Real}, z::AbstractVector{<:Real})

    # performant way of doing Sherman-Morrison update
    # Λ = Λ - (Λ*u)*(v'*Λ) / (1 + v'*Λ*u), with v = e_{n_eff}
    n = size(Λ, 1)
    @assert size(Λ, 2) == n
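    @assert length(u) == n && length(w) == n && length(z) == n
    @assert 1 <= n_eff <= n          # v is the unit vector e_{n_eff}

    mul!(w, Λ, u)                    # w = Λ*u
    z .= view(Λ, n_eff, :)           # z = Λ'*e_{n_eff} = row n_eff of Λ (copied, since ger! mutates Λ)

    denom = -1 / (1.0 + w[n_eff])    # v'*w reduces to w[n_eff]
    BLAS.ger!(denom, w, z, Λ)        # in place: Λ = denom*w*z' + Λ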

    return nothing
end

function test()
    n = 5000
    A  = rand(Uniform(), n, n)
    A_inv = inv(A)
    A_inv_2 = copy(A_inv)
    v = zeros(Float64, n)
    u = ones(Float64, n)
    w = Vector{Float64}(undef, n) 
    z = Vector{Float64}(undef, n) 

    # Use the unit vector at this index
    n_used = 1
    v[n_used] = 1.
    @btime update_sm2!($A_inv_2, $u,$v,$w,$z)
    @btime update_sm3!($A_inv_2, $u, $n_used, $w,$z)
end
test()

gives

  11.736 ms (0 allocations: 0 bytes)
  8.711 ms (0 allocations: 0 bytes)

on my machine, which is a modest further improvement.
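To check that the unit-vector version stays accurate when applied repeatedly, as in the coordinate descent loop, one can cycle k through 1:n and compare against the explicitly updated matrix. A sketch with placeholder data: `update_unit!` is a hypothetical stand-in for `update_sm3!`, the random `u` stands in for whatever the algorithm produces, and n is kept small so `inv` stays cheap:

```julia
using LinearAlgebra

# Sherman-Morrison update with v = e_k, following the same steps as update_sm3!.
function update_unit!(Λ::Matrix{Float64}, u::Vector{Float64}, k::Int,
                      w::Vector{Float64}, z::Vector{Float64})
    mul!(w, Λ, u)                # w = Λ*u
    z .= view(Λ, k, :)           # z = Λ'*e_k = row k of Λ (copied, since ger! mutates Λ)
    α = -1 / (1 + w[k])          # v'*w reduces to w[k] when v = e_k
    BLAS.ger!(α, w, z, Λ)        # in place: Λ = α*w*z' + Λ
    return Λ
end

n = 200                          # small size so the explicit inverse stays cheap
A = randn(n, n) + n * I          # shifted to keep it well conditioned (assumption)
Λ = inv(A)
w, z = similar(Λ, n), similar(Λ, n)
for k in 1:n                     # cycle through the unit vectors e_1, ..., e_n
    u = 0.1 * randn(n)           # placeholder: the real u comes from the algorithm
    update_unit!(Λ, u, k, w, z)
    A[:, k] .+= u                # A + u*e_k' adds u to column k
end
opnorm(Λ * A - I)                # accumulated error; expected to stay tiny here
```

On ill-conditioned problems, rounding error from many successive updates can accumulate, so an occasional fresh `inv` may be worth comparing against.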


My bad: the indices in the view were the wrong way around (now fixed).
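For reference, the two substitutions follow directly from the Sherman-Morrison formula with w = Λu and v = e_k (k standing for `n_eff`):

$$
v^\top w = e_k^\top w = w_k, \qquad
z = \Lambda^\top e_k = \left(e_k^\top \Lambda\right)^\top = \left(\Lambda_{k1}, \Lambda_{k2}, \dots, \Lambda_{kn}\right)^\top,
$$

$$
\Lambda \leftarrow \Lambda - \frac{w\, z^\top}{1 + w_k}.
$$

So z is row k of Λ, i.e. `view(Λ, n_eff, :)`, not column k.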
