Performance help for this short matrix function?

I’m computing regularized optimal transport distances as described in https://papers.nips.cc/paper/4927-sinkhorn-distances-lightspeed-computation-of-optimal-transport.pdf

The code is quite short:

using LinearAlgebra

function wd_sinkhorn(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)                    # elementwise kernel e^(-lambda * M)
    u = ones(n) / n                            # scaling vectors, initialized uniformly
    v = ones(n) / n
    temp_loc = K * v                           # reusable buffer for the matrix-vector products
    for _ in 1:iters
        # Sinkhorn fixed-point iterations: alternately rescale u and v
        LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)   # regularized transport plan
    return sum(p_lambda .* dm)                 # transport cost <P, M>
end

This seems to work well, but not as fast as I’d like. x and y are float vectors, typically with length between 1,000 and 10,000. dm is a square distance matrix matching the length of x and y.

Using @time I can confirm this function doesn’t allocate much, but it’s still painfully slow for large inputs or iteration counts. Is there anything else major to be done?
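In case it helps to reproduce the timings, inputs of the right shape can be generated along these lines (purely illustrative; my real x, y and dm come from actual data, and all names below are made up for the example):

n = 1000
x = rand(n); x ./= sum(x)                    # marginals normalized to sum to 1
y = rand(n); y ./= sum(y)
pts = rand(n)                                # random points on the line
dm = [abs(p - q) for p in pts, q in pts]     # n-by-n pairwise distance matrix

wd_sinkhorn(x, y, dm)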

Did you profile your code? What are the bottlenecks?


Does the number of iterations have to be fixed? Can’t you test if the solution has converged, and then break?

I have, but I think the results are misleading. Profiling shows more than half the backtraces occurring in the line p_lambda = Diagonal(u) * K * Diagonal(v). However, the actual execution time (measured by @elapsed) is almost perfectly proportional to the number of iterations, and that line is executed only once. So I suspect the profiler can't collect backtraces from inside some core linear algebra (BLAS?) calls.

Absolutely, in practice a convergence test on u is almost surely necessary. Nonetheless, trying out various iteration counts by hand shows that a couple thousand iterations are needed, so it's still desirable to make everything, and especially the inner loop, as fast as possible.
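Roughly, I have something like the following in mind, with an early exit once u stops changing (just a sketch; the tolerance tol and the name maxiters are arbitrary choices of mine, not from the paper):

using LinearAlgebra

function wd_sinkhorn_conv(x, y, dm, lambda = 100; tol = 1e-9, maxiters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    u = ones(n) / n
    v = ones(n) / n
    u_prev = similar(u)
    temp_loc = K * v
    for _ in 1:maxiters
        copyto!(u_prev, u)
        LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
        # stop once u has (approximately) stopped changing
        # (the subtraction allocates a small temporary; fine for a sketch)
        norm(u .- u_prev, Inf) < tol && break
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    return sum(p_lambda .* dm)
end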

Weird. You could also just do “ghetto” benchmarking by putting @time in front of lines / blocks of interest, or simply commenting lines out and re-running the test. It would be interesting to know how much time (as a fraction of the total) is spent in calls to mul!.

function wd_sinkhorn_countmultime(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    u = ones(n) / n
    v = ones(n) / n
    temp_loc = K * v
    multime = 0.0            # accumulated time spent in the two mul! calls
    for _ in 1:iters
        multime += @elapsed LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        multime += @elapsed LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    @show multime
    return sum(p_lambda .* dm)
end

julia> @elapsed wd_sinkhorn_countmultime(x, y, dm, 200, iters = 1000)
multime = 3.127651118999997
3.359015272

So, almost all the time is spent in mul!?

Yeah, looks like it. ProfileView shows that almost all the time is spent in BLAS gemv (called from the mul! lines), so I don't think there are any quick wins to be gained here. You could try compiling Julia with MKL instead of OpenBLAS if that's an option for you.


Yes, that’s also what I get:

julia> @profile wd_sinkhorn(x, y, dm)
julia> Profile.clear()
julia> @profile wd_sinkhorn(x, y, dm)
julia> Profile.print(format=:flat)
[...]
    24 /tmp/sinkhorn.jl   8  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
   485 /tmp/sinkhorn.jl  13  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
   434 /tmp/sinkhorn.jl  15  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
     1 /tmp/sinkhorn.jl  16  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
    27 /tmp/sinkhorn.jl  18  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
     2 /tmp/sinkhorn.jl  19  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
   973 /tmp/sinkhorn.jl   6  wd_sinkhorn(::Array{Float64,1}, ::Array{Float64...
   973 /tmp/sinkhorn.jl   6  wd_sinkhorn                                       

For reference, lines 13 & 15 correspond to the two mul! operations and take approximately 50% and 45% of the time, respectively. Line 18 corresponds to the Diagonal(u) * ... line and takes approximately 3% of the time.

Perhaps off topic, but is this now easier than compiling? I just found (but have not yet tried) MKL.jl.
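I haven't tried it, but from a quick read of its README the usage on recent Julia versions seems to be just loading the package before doing any linear algebra, roughly like below (treat this as my assumption about the package, not something I've verified):

# Assumption based on MKL.jl's README (recent Julia and MKL.jl versions):
# loading MKL first swaps the default BLAS backend from OpenBLAS to MKL.
using MKL
using LinearAlgebra

LinearAlgebra.BLAS.get_config()   # on Julia >= 1.7 this should list MKL instead of OpenBLAS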

Actually, one quick win is to materialize the transpose of K so that both matrix-vector multiplications have the form A' * b, which is more cache-friendly than A * b, i.e.:

function wd_sinkhorn2(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    Kt = copy(K')                              # materialize the transpose once, outside the loop
    u = ones(n) / n
    v = ones(n) / n
    temp_loc = K * v
    for _ in 1:iters
        LinearAlgebra.mul!(temp_loc, Kt', v)   # computes K * v via the transposed kernel
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    return sum(p_lambda .* dm)
end

Performance comparison:

Before:

julia> @btime wd_sinkhorn(x, y, dm) setup = begin
           n = 1000
           x = rand(n)
           y = rand(n)
           dm = rand(n, n)
       end
  117.842 ms (16 allocations: 30.56 MiB)

After:

julia> @btime wd_sinkhorn2(x, y, dm) setup = begin
           n = 1000
           x = rand(n)
           y = rand(n)
           dm = rand(n, n)
       end
  87.549 ms (18 allocations: 38.19 MiB)

You may find this package very convenient for this purpose:
