Performance help for this short matrix function?

linearalgebra

#1

I’m computing regularized optimal transport distances as described in https://papers.nips.cc/paper/4927-sinkhorn-distances-lightspeed-computation-of-optimal-transport.pdf

The code is quite short:

using LinearAlgebra

function wd_sinkhorn(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
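    # kernel matrix from the paper: elementwise exp of -lambda times the distances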
    K = exp.(-lambda .* dm)
    u = ones(n) / n
    v = ones(n) / n
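    # preallocated buffer reused by the in-place mul! calls below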
    temp_loc = K * v
    for _ in 1:iters
        LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    return sum(p_lambda .* dm)
end

This seems to work well, but not as fast as I’d like. x and y are float vectors, typically with length between 1,000 and 10,000. dm is a square distance matrix matching the length of x and y.
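
For concreteness, a typical call looks roughly like this (the data below is made up purely for illustration; in my application x and y are normalized to sum to 1, since the paper treats them as probability vectors):

# hypothetical test data: 1-D points and their pairwise distance matrix
n = 1000
pts = rand(n)
dm = [abs(pts[i] - pts[j]) for i in 1:n, j in 1:n]
x = rand(n); x ./= sum(x)   # marginals as probability vectors
y = rand(n); y ./= sum(y)

wd_sinkhorn(x, y, dm)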

Using @time I can confirm this function doesn’t allocate much, but it’s still painfully slow for large inputs or iteration counts. Is there anything else major to be done?


#2

Did you profile your code? What are the bottlenecks?


#3

Does the number of iterations have to be fixed? Can’t you test if the solution has converged, and then break?
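
Something like this, for instance (wd_sinkhorn_conv, maxiters, and the tolerance tol are names and values I'm making up, and checking convergence on u is just one possible criterion):

using LinearAlgebra

function wd_sinkhorn_conv(x, y, dm, lambda = 100; maxiters = 10 * lambda, tol = 1e-9)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    u = ones(n) / n
    v = ones(n) / n
    u_prev = similar(u)
    temp_loc = K * v
    for _ in 1:maxiters
        copyto!(u_prev, u)
        LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
        # break as soon as the scaling vector u has essentially stopped changing
        maximum(abs(u[i] - u_prev[i]) for i in eachindex(u)) < tol && break
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    return sum(p_lambda .* dm)
end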


#4

I have, but I think the results are misleading. Profiling shows more than half the backtraces occurring in the line p_lambda = Diagonal(u) * K * Diagonal(v). However, actual execution time (measured with @elapsed) is almost perfectly proportional to the number of iterations, and that line executes only once. So I suspect that some core linear algebra calls (BLAS?) can’t return backtraces.


#5

Absolutely, for practical use a convergence test on u is almost surely necessary. Nonetheless, just trying out various iteration counts by hand shows that a couple thousand iterations are needed, so it’s still desirable to make everything, and especially the inner loop, fast.


#6

Weird. You could also just do quick-and-dirty benchmarking by putting @time in front of lines / blocks of interest, or simply commenting lines out and re-running the test. It would be interesting to know how much of the total time is spent in calls to mul!.


#7
function wd_sinkhorn_countmultime(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    u = ones(n) / n
    v = ones(n) / n
    temp_loc = K * v
    multime = 0.0
    for _ in 1:iters
        multime += @elapsed LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        multime += @elapsed LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    @show multime
    return sum(p_lambda .* dm)
end

julia> @elapsed wd_sinkhorn_countmultime(x, y, dm, 200, iters = 1000)
multime = 3.127651118999997
3.359015272

So, almost all the time is spent in mul!?


#8

Yeah, looks like it. ProfileView shows that almost all the time is spent in BLAS gemv (called from the mul! lines), so I don’t think there are any quick wins to be gained here. You could try compiling Julia with MKL instead of OpenBLAS if that’s an option for you.
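
Before going that far, it may also be worth checking which BLAS you are actually running and with how many threads; something along these lines (BLAS.vendor() is the older API, replaced by BLAS.get_config() on recent Julia versions):

using LinearAlgebra

BLAS.vendor()                  # :openblas64 on a default build
BLAS.set_num_threads(4)        # gemv is memory-bound, so more threads may only help so much
LinearAlgebra.peakflops(2000)  # rough throughput check of the installed BLAS (uses gemm)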


#9

Yes, that’s also what I get:

julia> using Profile
julia> @profile wd_sinkhorn(x, y, dm)   # first run includes compilation, so discard it
julia> Profile.clear()
julia> @profile wd_sinkhorn(x, y, dm)
julia> Profile.print(format=:flat)
[...]
    24 /tmp/sinkhorn.jl   8  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
   485 /tmp/sinkhorn.jl  13  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
   434 /tmp/sinkhorn.jl  15  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
     1 /tmp/sinkhorn.jl  16  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
    27 /tmp/sinkhorn.jl  18  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
     2 /tmp/sinkhorn.jl  19  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
   973 /tmp/sinkhorn.jl   6  wd_sinkhorn(::Array{Float64,1}, ::Array{Float64...
   973 /tmp/sinkhorn.jl   6  wd_sinkhorn                                       

For reference, lines 13 & 15 correspond to the two mul! operations and take approximately 50% and 45% of the time, respectively. Line 18 corresponds to the Diagonal(u) * ... line and takes approximately 3% of the time.


#10

Perhaps off topic, but is this now easier than compiling Julia against MKL yourself? I just found (but have not yet tried) MKL.jl:


#11

Actually, one quick win you can get is to materialize the transpose of K and then make both matrix multiplications take the form A' * b, which is more cache-friendly than A * b, i.e.:

function wd_sinkhorn2(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    Kt = copy(K')
    u = ones(n) / n
    v = ones(n) / n
    temp_loc = K * v
    for _ in 1:iters
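        # Kt' * v equals K * v, but it reaches BLAS as a transposed matvec over Kt,
        # which is the more cache-friendly access pattern mentioned above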
        LinearAlgebra.mul!(temp_loc, Kt', v)
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    return sum(p_lambda .* dm)
end

Performance comparison:

Before:

julia> @btime wd_sinkhorn(x, y, dm) setup = begin
           n = 1000
           x = rand(n)
           y = rand(n)
           dm = rand(n, n)
       end
  117.842 ms (16 allocations: 30.56 MiB)

After:

julia> @btime wd_sinkhorn2(x, y, dm) setup = begin
           n = 1000
           x = rand(n)
           y = rand(n)
           dm = rand(n, n)
       end
  87.549 ms (18 allocations: 38.19 MiB)
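
The extra ~7.6 MiB allocated by wd_sinkhorn2 is presumably the materialized Kt: a 1000×1000 Float64 matrix is about 7.6 MiB.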


#12

You may find this package very convenient for this purpose: