I’m computing regularized optimal transport distances as described in https://papers.nips.cc/paper/4927-sinkhorn-distances-lightspeed-computation-of-optimal-transport.pdf
The code is quite short:
```julia
using LinearAlgebra

function wd_sinkhorn(x, y, dm, lambda = 100; iters = 10 * lambda)
    n = length(x)
    @assert n == length(y)
    K = exp.(-lambda .* dm)                  # Gibbs kernel
    u = ones(n) / n
    v = ones(n) / n
    temp_loc = K * v                         # preallocated work buffer
    for _ in 1:iters
        LinearAlgebra.mul!(temp_loc, K, v)   # temp_loc = K * v
        u .= x ./ temp_loc
        LinearAlgebra.mul!(temp_loc, K', u)  # temp_loc = K' * u
        v .= y ./ temp_loc
    end
    p_lambda = Diagonal(u) * K * Diagonal(v) # regularized transport plan
    return sum(p_lambda .* dm)
end
```
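For reference, the loop implements the Sinkhorn fixed-point iteration. Written out with r = x, c = y, M = dm and λ = lambda (notation loosely following the paper; ⊘ is elementwise division), it is:

$$
\begin{aligned}
K &= e^{-\lambda M} \quad \text{(elementwise)},\\
u &\leftarrow r \oslash (K v), \qquad v \leftarrow c \oslash (K^\top u),\\
P^\lambda &= \operatorname{diag}(u)\, K \operatorname{diag}(v),\\
d_M^\lambda(r, c) &= \langle P^\lambda, M \rangle = \sum_{i,j} u_i\, K_{ij}\, M_{ij}\, v_j .
\end{aligned}
$$

Each pass is therefore two dense n × n matrix-vector products, i.e. O(n²) work. With n = 10,000 and the default iters = 10 * lambda = 1000, that comes to roughly 2 × 10^11 multiply-adds in total.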
This seems to work well, but not as fast as I’d like.
x and y are float vectors, typically with length between 1,000 and 10,000.
dm is a square distance matrix matching the length of x and y (n × n, with n = length(x)).
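To reproduce timings at these sizes, here is a minimal sketch for generating inputs of this shape. It assumes 1-D points in [0, 1] with |a - b| as the distance, and x, y normalized to sum to 1 (Sinkhorn iterations for balanced transport expect equal total mass). `make_problem` is a hypothetical helper, not part of any library:

```julia
using Random

# Hypothetical benchmarking helper (assumption: 1-D points, |a - b| cost).
# Returns random histograms x, y on n points and their n × n distance matrix.
function make_problem(n; rng = Random.default_rng())
    pts = sort(rand(rng, n))                     # n points in [0, 1]
    x = rand(rng, n); x ./= sum(x)               # nonnegative, sums to 1
    y = rand(rng, n); y ./= sum(y)               # same total mass as x
    dm = [abs(a - b) for a in pts, b in pts]     # symmetric, zero diagonal
    return x, y, dm
end
```

For example, `x, y, dm = make_problem(10_000)` followed by `@time wd_sinkhorn(x, y, dm)`. At n = 10,000, dm alone holds 10^8 Float64 values (about 800 MB), and `exp.(-lambda .* dm)` allocates a second matrix of the same size.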
Using @time, I can confirm that this function doesn’t allocate much, but it’s still painfully slow for large inputs or high iteration counts. Is there anything else major to be done?
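Two things may matter more than the small per-iteration allocations: the fixed iteration count, and the last two lines, which build the n × n plan and `p_lambda .* dm` only to take one sum. Below is a minimal sketch addressing both. It assumes that an L1 tolerance on the row marginals is an acceptable stopping rule. The name `wd_sinkhorn_tol` and the `tol` keyword are hypothetical, and this is one possible variant, not the paper's code:

```julia
using LinearAlgebra

# Hypothetical variant of wd_sinkhorn (name and `tol` keyword are assumptions).
# Same Sinkhorn updates, plus (1) an early exit once the row sums of the plan
# diag(u) * K * diag(v) are within `tol` of x in L1 norm, and (2) a final cost
# evaluated in one pass instead of building two n × n temporaries.
function wd_sinkhorn_tol(x, y, dm, lambda = 100; iters = 10 * lambda, tol = 1e-9)
    n = length(x)
    @assert n == length(y) && size(dm) == (n, n)
    K = exp.(-lambda .* dm)          # Gibbs kernel, allocated once
    u = fill(1.0 / n, n)
    v = fill(1.0 / n, n)
    tmp = K * v                      # invariant: tmp == K * v at the top of each pass
    niter = 0
    for it in 1:iters
        niter = it
        u .= x ./ tmp
        mul!(tmp, K', u)             # tmp = K' * u
        v .= y ./ tmp                # column sums of the plan now equal y
        mul!(tmp, K, v)              # tmp = K * v, reused by the next u update
        # Row sums of the plan are u .* (K * v) = u .* tmp, so this check
        # costs O(n) and needs no extra matrix-vector product.
        err = 0.0
        @inbounds for i in 1:n
            err += abs(u[i] * tmp[i] - x[i])
        end
        err < tol && break
    end
    # <diag(u) * K * diag(v), dm> = sum_ij u[i] * K[i, j] * dm[i, j] * v[j]
    cost = 0.0
    @inbounds for j in 1:n           # column-major order: i varies fastest
        s = 0.0
        for i in 1:n
            s += u[i] * K[i, j] * dm[i, j]
        end
        cost += s * v[j]
    end
    return cost, niter
end
```

For example, `cost, k = wd_sinkhorn_tol(x, y, dm, 100; tol = 1e-6)` returns the distance and the number of passes actually used. Each pass still performs exactly two matrix-vector products, so any speedup comes from running fewer passes. How loose `tol` can be depends on how accurate the distance needs to be.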