In the algorithm, `v` is always a standard unit vector, so I’d avoid multiplying by it altogether: replace the `dot` with a plain index lookup, and replace the `mul!` with a copy from a row view:
using LinearAlgebra, BenchmarkTools, Random, Distributions
"""
    update_sm2!(Λ, u, v, w, z)

Apply the Sherman-Morrison rank-one update to `Λ` in place:

    Λ ← Λ - (Λ*u) * (v'*Λ) / (1 + v'*Λ*u)

# Arguments
- `Λ`: square `n × n` matrix, overwritten with the updated matrix.
- `u`, `v`: vectors of length `n` defining the rank-one term `u*v'`.
- `w`, `z`: caller-provided work vectors of length `n`. They are overwritten
  with `Λ*u` and `Λ'*v` respectively, both computed from the input `Λ`.

Returns `nothing`.

The rank-one update itself is delegated to `BLAS.ger!`, so `Λ`, `w` and `z`
are expected to be strided arrays sharing one BLAS floating-point element type.
"""
function update_sm2!(Λ::AbstractMatrix{<:Real}, u::AbstractVector{<:Real},
                     v::AbstractVector{<:Real}, w::AbstractVector{<:Real},
                     z::AbstractVector{<:Real})
    n = size(Λ, 1)
    @assert size(Λ, 2) == n
    @assert length(u) == n && length(v) == n
    @assert length(w) == n && length(z) == n
    mul!(w, Λ, u)   # w = Λ * u
    mul!(z, Λ', v)  # z = Λ' * v, i.e. the row v' * Λ stored as a column
    # Keep the scale factor in Λ's element type. A hard-coded Float64 literal
    # here would promote it to Float64 and make `BLAS.ger!` reject Float32 input.
    α = convert(eltype(Λ), -1 / (1 + dot(v, w)))
    BLAS.ger!(α, w, z, Λ)  # Λ += α * w * z'
    return nothing
end
"""
    update_sm3!(Λ, u, n_eff, w, z)

Apply the Sherman-Morrison rank-one update to `Λ` in place for the special case
where the second vector is the standard unit vector with its one at index
`n_eff`:

    Λ ← Λ - (Λ*u) * Λ[n_eff, :]' / (1 + (Λ*u)[n_eff])

Because the vector is a unit vector, `v'*Λ` is just row `n_eff` of `Λ` and
`v'*(Λ*u)` is just element `n_eff` of `Λ*u`, so neither product is formed.

# Arguments
- `Λ`: square `n × n` matrix, overwritten with the updated matrix.
- `u`: vector of length `n`.
- `n_eff`: index of the nonzero entry of the implicit unit vector.
- `w`, `z`: caller-provided work vectors of length `n`. They are overwritten
  with `Λ*u` and row `n_eff` of `Λ`, both taken from the input `Λ`.

Returns `nothing`. Throws a `BoundsError` if `n_eff` is not a valid row index.

The rank-one update itself is delegated to `BLAS.ger!`, so `Λ`, `w` and `z`
are expected to be strided arrays sharing one BLAS floating-point element type.
"""
function update_sm3!(Λ::AbstractMatrix{<:Real}, u::AbstractVector{<:Real}, n_eff::Int,
                     w::AbstractVector{<:Real}, z::AbstractVector{<:Real})
    n = size(Λ, 1)
    @assert size(Λ, 2) == n
    @assert length(u) == n && length(w) == n && length(z) == n
    # Reject a bad row index before any work vector has been overwritten.
    checkbounds(Λ, n_eff, :)
    mul!(w, Λ, u)             # w = Λ * u
    z .= view(Λ, n_eff, :)    # z = row n_eff of Λ; copied because ger! mutates Λ
    # Keep the scale factor in Λ's element type. A hard-coded Float64 literal
    # here would promote it to Float64 and make `BLAS.ger!` reject Float32 input.
    α = convert(eltype(Λ), -1 / (1 + w[n_eff]))
    BLAS.ger!(α, w, z, Λ)     # Λ += α * w * z'
    return nothing
end
"""
    test()

Benchmark `update_sm2!` against `update_sm3!` on the inverse of a random
5000×5000 matrix, after first checking that both produce the same update.
Prints the two `@btime` results and returns `nothing`.
"""
function test()
    n = 5000
    rng = MersenneTwister(1234)  # fixed seed so repeated runs use the same matrix
    A = rand(rng, Uniform(), n, n)
    A_inv = inv(A)
    v = zeros(Float64, n)
    u = ones(Float64, n)
    w = Vector{Float64}(undef, n)
    z = Vector{Float64}(undef, n)
    # Use the unit vector at this index
    n_used = 1
    v[n_used] = 1.0

    # Sanity check: the unit-vector shortcut must agree with the general update.
    Λ2 = copy(A_inv)
    Λ3 = copy(A_inv)
    update_sm2!(Λ2, u, v, w, z)
    update_sm3!(Λ3, u, n_used, w, z)
    @assert Λ2 ≈ Λ3 "update_sm2! and update_sm3! disagree"

    # Both functions mutate their first argument. Give every sample a fresh copy
    # (setup + evals=1) so each timing starts from the same matrix instead of
    # repeatedly updating an already-updated one.
    @btime update_sm2!(M, $u, $v, $w, $z) setup=(M = copy($A_inv)) evals=1
    @btime update_sm3!(M, $u, $n_used, $w, $z) setup=(M = copy($A_inv)) evals=1
    return nothing
end
test()
gives
11.736 ms (0 allocations: 0 bytes)
8.711 ms (0 allocations: 0 bytes)
on my device, which looks like a modest further improvement.