# Performance help for this short matrix function?

I’m computing regularized optimal transport distances as described in https://papers.nips.cc/paper/4927-sinkhorn-distances-lightspeed-computation-of-optimal-transport.pdf

The code is quite short:

```julia
using LinearAlgebra

function wd_sinkhorn(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)    # Gibbs kernel
    u = ones(n) / n
    v = ones(n) / n
    temp_loc = K * v           # preallocated scratch vector
    for _ in 1:iters
        LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)    # regularized transport plan
    return sum(p_lambda .* dm)
end
```

This seems to work well, but not as fast as I’d like. `x` and `y` are float vectors, typically with length between 1,000 and 10,000. `dm` is a square distance matrix whose dimensions match the length of `x` and `y`.

Using `@time` I can confirm this function doesn’t allocate much, but it’s still painfully slow for large inputs or iteration counts. Is there anything else major to be done?
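For concreteness, here’s an illustrative call with made-up inputs; the sizes and the normalization of `x` and `y` to probability vectors are just my reading of the paper, not requirements imposed by the code:

```julia
# Hypothetical inputs, only to show the expected shapes.
n = 1_000
x = rand(n); x ./= sum(x)    # marginals; the paper treats these as probability vectors
y = rand(n); y ./= sum(y)
dm = rand(n, n)              # stand-in for a real pairwise distance matrix

wd_sinkhorn(x, y, dm, 100; iters = 1000)
```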

Did you profile your code? What are the bottlenecks?


Does the number of iterations have to be fixed? Can’t you test if the solution has converged, and then break?

I have, but I think the results are misleading. Profiling shows more than half the backtraces occurring in the line `p_lambda = Diagonal(u) * K * Diagonal(v)`. However, actual execution time (measured with `@elapsed`) is almost perfectly proportional to the number of iterations, and that line runs only once. So I suspect that some core linear algebra calls (BLAS?) can’t return backtraces.

Absolutely, for use in practice a convergence test on `u` is almost surely necessary. Nonetheless, trying out various iteration counts by hand shows that a couple thousand iterations are required, so it’s still desirable to make everything, and especially the inner loop, fast.
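For the record, here is an untested sketch of the early exit I have in mind; the `tol` keyword and the `u_prev` buffer are my additions, not part of the function above:

```julia
function wd_sinkhorn_conv(x, y, dm, lambda = 100; iters = 10 * lambda, tol = 1e-9)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    u = ones(n) / n
    v = ones(n) / n
    u_prev = similar(u)    # buffer for the previous scaling vector
    temp_loc = K * v
    for _ in 1:iters
        copyto!(u_prev, u)
        LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
        # break once successive scaling vectors agree to within tol
        maximum(abs(u[i] - u_prev[i]) for i in eachindex(u)) < tol && break
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    return sum(p_lambda .* dm)
end
```

The extra `copyto!` and the reduction are O(n) per iteration, so they should be negligible next to the O(n²) `mul!` calls.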

Weird. You could also just do “ghetto” benchmarking by putting `@time` in front of lines or blocks of interest, or simply commenting lines out and re-running the test. It would be interesting to know how much of the total time is spent in calls to `mul!`.

```julia
function wd_sinkhorn_countmultime(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    u = ones(n) / n
    v = ones(n) / n
    temp_loc = K * v
    multime = 0.0    # accumulated wall time spent in mul!
    for _ in 1:iters
        multime += @elapsed LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        multime += @elapsed LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    @show multime
    return sum(p_lambda .* dm)
end
```

```julia
julia> @elapsed wd_sinkhorn_countmultime(x, y, dm, 200, iters = 1000)
multime = 3.127651118999997
3.359015272
```

So, almost all the time is spent in `mul!`?

Yeah, looks like it. ProfileView shows that almost all the time is spent in BLAS `gemv` (called from the `mul!` lines), so I don’t think there are any quick wins to be had here. You could try compiling Julia with MKL instead of OpenBLAS if that’s an option for you.


Yes, that’s also what I get:

```julia
julia> using Profile

julia> @profile wd_sinkhorn(x, y, dm)    # warm-up run (discarded below)

julia> Profile.clear()

julia> @profile wd_sinkhorn(x, y, dm)

julia> Profile.print(format=:flat)
[...]
  24 /tmp/sinkhorn.jl   8  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
 485 /tmp/sinkhorn.jl  13  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
 434 /tmp/sinkhorn.jl  15  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
   1 /tmp/sinkhorn.jl  16  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
  27 /tmp/sinkhorn.jl  18  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
   2 /tmp/sinkhorn.jl  19  #wd_sinkhorn#3(::Int64, ::Function, ::Array{Flo...
 973 /tmp/sinkhorn.jl   6  wd_sinkhorn(::Array{Float64,1}, ::Array{Float64...
 973 /tmp/sinkhorn.jl   6  wd_sinkhorn
```

For reference, lines 13 & 15 correspond to the two `mul!` operations and take approximately 50% and 45% of the time, respectively. Line 18 corresponds to the `Diagonal(u) * ...` line and takes approximately 3% of the time.

Perhaps off topic, but is this now easier than compiling? I just found (but have not yet tried) MKL.jl: https://github.com/JuliaLinearAlgebra/MKL.jl
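If I read its README correctly (untested, so consider this a sketch of my understanding rather than verified usage), on recent Julia versions it’s just a package load, and you can check that the swap took effect:

```julia
using MKL            # swaps the default OpenBLAS backend for MKL
using LinearAlgebra

# On Julia >= 1.7 this should report an MKL library (libmkl_rt)
# instead of libopenblas if the swap succeeded.
BLAS.get_config()
```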

Actually, one quick win is to materialize the transpose of `K` and make both matrix multiplications take the form `A' * b`, which is more cache-friendly than `A * b`:

```julia
function wd_sinkhorn2(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    Kt = copy(K')    # materialize the transpose once, outside the loop
    u = ones(n) / n
    v = ones(n) / n
    temp_loc = K * v
    for _ in 1:iters
        LinearAlgebra.mul!(temp_loc, Kt', v)    # Kt' * v == K * v, but in A' * b form
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    return sum(p_lambda .* dm)
end
```

Performance comparison:

Before:

```julia
julia> @btime wd_sinkhorn(x, y, dm) setup = begin
           n = 1000
           x = rand(n)
           y = rand(n)
           dm = rand(n, n)
       end
  117.842 ms (16 allocations: 30.56 MiB)
```

After:

```julia
julia> @btime wd_sinkhorn2(x, y, dm) setup = begin
           n = 1000
           x = rand(n)
           y = rand(n)
           dm = rand(n, n)
       end
  87.549 ms (18 allocations: 38.19 MiB)
```

You may find this package very convenient for this purpose:
