using LinearAlgebra
function wd_sinkhorn(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)    # Gibbs kernel built from the cost matrix
    u = ones(n) / n
    v = ones(n) / n
    temp_loc = K * v           # preallocated buffer reused by both mul! calls
    for _ in 1:iters
        # Sinkhorn iterations: alternately rescale u and v to match the marginals x and y
        LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)    # regularized transport plan
    return sum(p_lambda .* dm)                  # transport cost under that plan
end
This seems to work well, but not as fast as I’d like. x and y are float vectors, typically with length between 1,000 and 10,000. dm is a square distance matrix matching the length of x and y.
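For reference, the inputs are built along these lines; the normalization of x and y to sum to 1 and the 1-D points behind dm are just my conventions here, not anything the function checks:

n = 2_000
pts_a = sort(rand(n))
pts_b = sort(rand(n))
dm = [abs(a - b) for a in pts_a, b in pts_b]    # n×n cost/distance matrix

x = rand(n); x ./= sum(x)    # source weights, normalized to sum to 1
y = rand(n); y ./= sum(y)    # target weights, normalized to sum to 1

wd = wd_sinkhorn(x, y, dm, 100; iters = 1000)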
Using @time I can confirm this function doesn’t allocate much, but it’s still painfully slow for large inputs or iteration counts. Is there anything else major to be done?
I have, but I think the results are misleading. Profiling shows more than half of the backtraces landing in the line p_lambda = Diagonal(u) * K * Diagonal(v). However, actual execution time (measured by @elapsed) is almost perfectly proportional to the number of iterations, and that line runs only once. So I suspect that some core linear algebra calls (BLAS?) can't return backtraces.
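For completeness, the profiling recipe was essentially the following; I haven't tried the C = true variant of Profile.print, which is supposed to pull C (i.e. BLAS) frames into the report:

using Profile
Profile.clear()
@profile wd_sinkhorn(x, y, dm, 200; iters = 1000)
Profile.print(format = :flat, sortedby = :count)    # Julia frames only by default
# Profile.print(C = true)                           # would also include C / BLAS frames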
Absolutely, for practical use a convergence test on u is almost surely necessary. Still, trying out various iteration counts by hand shows that a couple thousand iterations are needed, so it's worth making everything, and especially the inner loop, as fast as possible.
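Roughly what I have in mind for that test, as a sketch only (the name wd_sinkhorn_conv, the u-based stopping criterion, and the 1e-9 tolerance are placeholders, not tuned):

using LinearAlgebra

function wd_sinkhorn_conv(x, y, dm, lambda = 100; tol = 1e-9, maxiters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    u = ones(n) / n
    v = ones(n) / n
    u_prev = similar(u)
    temp_loc = K * v
    for _ in 1:maxiters
        copyto!(u_prev, u)
        LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
        # placeholder stopping rule: break once u has essentially stopped moving
        # (allocation-free max-abs difference)
        maximum(abs(a - b) for (a, b) in zip(u, u_prev)) < tol && break
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    return sum(p_lambda .* dm)
end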
Weird. You could also just do quick-and-dirty benchmarking by putting @time (or @elapsed) in front of the lines or blocks of interest, or simply commenting lines out and re-running the test. It would be interesting to know what fraction of the total time is spent in the calls to mul!.
function wd_sinkhorn_countmultime(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    u = ones(n) / n
    v = ones(n) / n
    temp_loc = K * v
    multime = 0.0
    for _ in 1:iters
        multime += @elapsed LinearAlgebra.mul!(temp_loc, K, v)
        u .= x ./ temp_loc
        multime += @elapsed LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    @show multime
    return sum(p_lambda .* dm)
end
julia> @elapsed wd_sinkhorn_countmultime(x, y, dm, 200, iters = 1000)
multime = 3.127651118999997
3.359015272
Yeah, looks like it. ProfileView shows that almost all the time is spent in BLAS gemv (called from the mul! lines), so I don't think there are any quick wins to be gained here. You could try compiling Julia with MKL instead of OpenBLAS if that's an option for you.
For reference, the two mul! lines take approximately 50% and 45% of the time respectively, and the Diagonal(u) * K * Diagonal(v) line takes approximately 3%.
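Regarding the MKL suggestion: on Julia 1.7+ the MKL.jl package should let you swap the BLAS backend without rebuilding Julia. A rough sketch, assuming the package is installed:

using LinearAlgebra
using MKL                  # swaps the BLAS backend at load time (Julia 1.7+, MKL.jl installed)

BLAS.get_config()          # should now list MKL instead of OpenBLAS
BLAS.set_num_threads(4)    # the BLAS thread count is worth tuning either way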
Actually, one quick win is to materialize the transpose of K and then make both matrix-vector multiplications take the form A' * b, which is more cache-friendly than A * b, i.e.:
function wd_sinkhorn2(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)
    Kt = copy(K')    # materialized transpose, so both products below hit the transposed-gemv path
    u = ones(n) / n
    v = ones(n) / n
    temp_loc = K * v
    for _ in 1:iters
        LinearAlgebra.mul!(temp_loc, Kt', v)    # same result as K * v, computed as Kt' * v
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v)
    return sum(p_lambda .* dm)
end
Performance comparison:
Before:
julia> @btime wd_sinkhorn(x, y, dm) setup = begin
           n = 1000
           x = rand(n)
           y = rand(n)
           dm = rand(n, n)
       end
  117.842 ms (16 allocations: 30.56 MiB)
After:
julia> @btime wd_sinkhorn2(x, y, dm) setup = begin
           n = 1000
           x = rand(n)
           y = rand(n)
           dm = rand(n, n)
       end
  87.549 ms (18 allocations: 38.19 MiB)
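If you want to sanity-check the cache-friendliness claim in isolation, a standalone micro-benchmark along these lines (sizes arbitrary) compares a plain gemv against the transposed-gemv path computing the same product:

using LinearAlgebra, BenchmarkTools

n = 1000
A   = rand(n, n)
At  = copy(A')    # materialized transpose
b   = rand(n)
out = zeros(n)

@btime LinearAlgebra.mul!($out, $A, $b)      # plain gemv: A * b
@btime LinearAlgebra.mul!($out, $At', $b)    # transposed gemv computing the same A * b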