Slow L-BFGS code

I’m customizing the Burer-Monteiro method for a specific problem, which involves solving many subproblems with L-BFGS. However, I noticed that my L-BFGS runs slowly: the critical operation of computing the descent direction is about 2x slower than I expected.

There are two main performance-critical operations in my descent-direction computation: one is dot(A, B), where A and B are matrices, and the other is @. A += alpha * B, i.e. a daxpy operation. I benchmarked them separately and found that computing the descent direction takes about twice as long as the sum of the running times of all performance-critical operations.

Here is an MWE:

using LinearAlgebra, SparseArrays, BenchmarkTools
using Random


# seed for reproducibility
Random.seed!(11235813)

"""
Vector of L-BFGS
"""
struct LBFGSVector{T <: AbstractFloat}
    # notice that we use matrix instead 
    # of vector to store s and y because our
    # decision variables are matrices
    # s = xβ‚–β‚Šβ‚ - xβ‚– 
    s::Matrix{T}
    # y = βˆ‡ f(xβ‚–β‚Šβ‚) - βˆ‡ f(xβ‚–)
    y::Matrix{T}
    # ρ = 1/(⟨y, s⟩)
    ρ::Base.RefValue{T}
    # temporary variable
    a::Base.RefValue{T}
end

"""
History of l-bfgs vectors
"""
struct LBFGSHistory{Ti <: Integer, Tv <: AbstractFloat}
    # number of l-bfgs vectors
    m::Ti
    vecs::Vector{LBFGSVector{Tv}}
    # the index of the latest l-bfgs vector
    # we use a cyclic array to store l-bfgs vectors
    latest::Base.RefValue{Ti}
end


Base.length(lbfgshis::LBFGSHistory) = lbfgshis.m


"""
Computing the descent direction, here I omit some details like 
negating the direction to highlight the main performance issue.
"""
function LBFGS_dir!(
    dir::Matrix{Tv},
    lbfgshis::LBFGSHistory{Ti, Tv},
) where {Ti <: Integer, Tv <: AbstractFloat}
    m = lbfgshis.m
    lst = lbfgshis.latest[]
    #here, dir, s and y are all matrices
    j = lst
    for i = 1:m 
        α = lbfgshis.vecs[j].ρ[] * dot(lbfgshis.vecs[j].s, dir)
        @. dir -= lbfgshis.vecs[j].y * α
        lbfgshis.vecs[j].a[] = α
        j -= 1
        if j == 0
            j = m
        end
    end

    j = mod(lst, m) + 1
    for i = 1:m 
        β = lbfgshis.vecs[j].ρ[] * dot(lbfgshis.vecs[j].y, dir)
        γ = lbfgshis.vecs[j].a[] - β
        @. dir += lbfgshis.vecs[j].s * γ
        j += 1
        if j == m + 1
            j = 1
        end
    end
end

# Benchmark code
numlbfgsvecs = 4 
n = 8000
r = 41
R = randn(n, r)
dir = randn(n, r)
lbfgshis = LBFGSHistory{Int64, Float64}(numlbfgsvecs, LBFGSVector{Float64}[], Ref(numlbfgsvecs))

for i = 1:numlbfgsvecs
    push!(lbfgshis.vecs, 
        LBFGSVector(similar(R), similar(R), Ref(randn(Float64)), Ref(randn(Float64))))
end

@benchmark LBFGS_dir!($dir, $lbfgshis)

My benchmark results look like:

BenchmarkTools.Trial: 864 samples with 1 evaluation.
 Range (min … max):  5.739 ms …  11.223 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     5.775 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.785 ms ± 187.210 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  5.74 ms         Histogram: frequency by time        5.84 ms <

 Memory estimate: 0 bytes, allocs estimate: 0.

I also benchmarked the critical operations separately; the results are below.

@benchmark dot($R, $dir)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  11.322 μs … 257.907 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     13.815 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   14.321 μs ±   3.010 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  11.3 μs         Histogram: frequency by time         18.5 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

function operator2!(
    C::Matrix{Tv},
    A::Matrix{Tv},
    alpha::Tv,
) where {Tv <: AbstractFloat}
    @. C += alpha * A
end


function operator3!(
    C::Matrix{Tv},
    A::Matrix{Tv},
    alpha::Tv,
) where {Tv <: AbstractFloat}
    @. C -= alpha * A
end

alpha = randn(Float64)

@benchmark operator2!($dir, $R, $alpha)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  307.719 μs … 881.443 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     308.976 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   309.498 μs ±   7.404 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  308 μs        Histogram: log(frequency) by time        315 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

@benchmark operator3!($dir, $R, $alpha)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  307.490 μs …  4.655 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     308.765 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   309.785 μs ± 43.610 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  307 μs        Histogram: log(frequency) by time       316 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

Since my code does 8 dots and 8 daxpys, I would expect the total running time to be slightly more than 2.5 ms (8 × ~14 μs + 8 × ~310 μs ≈ 2.6 ms), maybe 3 ms. But it is 5.775 ms. I suspect it’s because of the way I declare and instantiate LBFGSHistory, but I’m not quite sure.

Any advice would be really helpful.

Did you profile the code to see which other functions you spend time in?
When I do, nearly all of the time is spent in these two lines:

@. dir -= lbfgshis.vecs[j].y * α
@. dir += lbfgshis.vecs[j].s * γ

Replacing these operations with the more optimized

LinearAlgebra.axpy!(-α, dir, lbfgshis.vecs[j].y)
LinearAlgebra.axpy!(γ, dir, lbfgshis.vecs[j].s)

yields a significant speedup (about 3x).


Is the problem here that this code is missing the fastmath flag needed to vectorize and use fma?

My main concern was that daxpy is the most time-consuming operation, and I do exactly 8 of them. I benchmarked the daxpy as I wrote it, which took about 370 μs on two 8000 × 41 matrices, but the whole procedure takes > 5 ms, which is far more than 8 × 370 μs, so I was confused.

Thanks for the tip! The last two arguments of axpy! need to be exchanged.
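
For anyone finding this later, the corrected calls would look something like this (a minimal sketch using my field names; axpy!(a, X, Y) updates Y in place as Y .+= a .* X, so the array being modified goes last):

LinearAlgebra.axpy!(-α, lbfgshis.vecs[j].y, dir)  # dir .-= α .* y
LinearAlgebra.axpy!(γ, lbfgshis.vecs[j].s, dir)   # dir .+= γ .* s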


No, it looks like it’s that axpy! is multithreaded. If I do:

using BenchmarkTools
using LinearAlgebra: axpy!

function foo1!(dir, α, γ, y, s)
    return @. dir += s * γ - y * α
end

@fastmath function foo2!(dir, α, γ, y, s)
    return @. dir += s * γ - y * α
end

function foo3!(dir, α, γ, y, s)
    return axpy!(γ, s, axpy!(-α, y, dir))
end

dir = zeros(8000*41); y = copy(dir); s = copy(dir);

it gives:

julia> @btime foo1!($dir, 0.1, 0.2, $y, $s);
  162.673 μs (0 allocations: 0 bytes)

julia> @btime foo2!($dir, 0.1, 0.2, $y, $s); # uses @fastmath
  158.207 μs (0 allocations: 0 bytes)

julia> @btime foo3!($dir, 0.1, 0.2, $y, $s); # uses axpy!
  85.101 μs (0 allocations: 0 bytes)

This is with the default LinearAlgebra.BLAS.get_num_threads() == 3 on my computer. However, if I turn off BLAS multi-threading, it is slower:

julia> LinearAlgebra.BLAS.set_num_threads(1);

julia> @btime foo3!($dir, 0.1, 0.2, $y, $s);
  221.990 μs (0 allocations: 0 bytes)

probably because foo1! and foo2! use a single fused loop, unlike axpy!, which requires two loops (and unlike the original code by @yhuang above).

If I use LoopVectorization.jl to multi-thread it:

using LoopVectorization

function foo4!(dir, α, γ, y, s)
    length(dir) == length(y) == length(s) || throw(DimensionMismatch())
    @tturbo for i = 1:length(dir)
        dir[i] += s[i] * γ - y[i] * α
    end
    return dir
end

then (with julia -t 3 to also use 3 threads), I get:

julia> @btime foo4!($dir, 0.1, 0.2, $y, $s);
  59.584 μs (0 allocations: 0 bytes)

which is the fastest yet.
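
The same pattern could presumably be dropped straight into the loops of LBFGS_dir!; here is a rough, untested sketch (the helper name lbfgs_axpy! is just for illustration, and the matrices are indexed linearly, as in foo4! above):

using LoopVectorization

# hypothetical helper: in-place dir += a * A, written as an explicit loop
# so @tturbo can vectorize and multi-thread it
function lbfgs_axpy!(dir, A, a)
    length(dir) == length(A) || throw(DimensionMismatch())
    @tturbo for i = 1:length(dir)
        dir[i] += a * A[i]
    end
    return dir
end

# inside the two-loop recursion, the updates would then become:
#   lbfgs_axpy!(dir, lbfgshis.vecs[j].y, -α)
#   lbfgs_axpy!(dir, lbfgshis.vecs[j].s, γ)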
