Computing discrete Wasserstein (EMD) distance in Julia?

question

#1

Is there a Julia library for computing the discrete Wasserstein (EMD) distance between two discrete distributions?
Python seems to have a tool for that, http://pot.readthedocs.io/en/stable/index.html, but if a pure Julia solution is available, it would be better than using PyCall.


#2

Thanks for the link, that looks cool! Unfortunately, I don’t know of a Julia package for optimal transport.

From a (very) short glance at the source, it looks like POT does most of the numerical heavy lifting in C/C++ libraries, but the Python code is also non-trivial.

If this is correct, then PyCall looks like a good choice: there is no big speed-up to gain if POT already uses fast C code, and calling the C libraries from Julia would be non-trivial because the Python code is not just a thin wrapper but contains real logic. Of course, if you want to experiment with new algorithms for optimal transport, then this state of affairs sucks (two-language problem); it also costs you heavy dependencies.

Is this assessment correct? You probably know more about POT than me.

I have never used POT, but they do offer a buffet of solvers and have put in the effort to offload a lot of computation into faster languages. Choices like Sinkhorn vs. a direct solver are probably more relevant than the choice of programming language anyway (algorithm class very often beats implementation / hardware). I am not sure how fast they are compared to other existing libraries.
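To make the "Sinkhorn vs direct solver" distinction concrete, here is a minimal pure-Python sketch of Sinkhorn iterations for entropy-regularised OT. It is not POT's implementation; the function name `sinkhorn_cost` and the parameters `reg` and `n_iter` are made up for illustration, and a serious version would work in the log domain to avoid underflow for small `reg`.

```python
import math

def sinkhorn_cost(a, b, cost, reg=0.05, n_iter=500):
    """Approximate EMD between histograms a and b via Sinkhorn iterations.

    a, b  : lists of non-negative weights, each summing to 1
    cost  : cost[i][j] is the ground distance between bin i of a and bin j of b
    reg   : entropic regularisation strength (assumed value, smaller = closer to EMD)
    """
    n, m = len(a), len(b)
    # Gibbs kernel K[i][j] = exp(-cost[i][j] / reg)
    K = [[math.exp(-cost[i][j] / reg) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iter):
        # Alternately rescale rows and columns so the plan matches both marginals.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    # Transport plan is P[i][j] = u[i] * K[i][j] * v[j]; return its total cost.
    return sum(u[i] * K[i][j] * v[j] * cost[i][j]
               for i in range(n) for j in range(m))

# Two-point distributions on x = [0, 1] and y = [1, 2], cost |x_i - y_j|.
shift_cost = [[1.0, 2.0], [0.0, 1.0]]
shifted = sinkhorn_cost([0.5, 0.5], [0.5, 0.5], shift_cost)
```

A direct (network simplex or LP) solver returns the exact optimum but scales worse with problem size; Sinkhorn trades a small bias, controlled by `reg`, for simple matrix-vector operations.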

Or are you in the situation where you want to solve many cheap OT problems instead of few expensive ones? Then you are in a bad situation, I fear :frowning:


#3

I am also new to POT. I had a look at their code and made similar observations. The Python portions of their code are not trivial, and I am not sure whether that will cause performance problems.

Yea, I am in the situation where I want to solve many cheap OT problems a lot of time :frowning:


#4

Yea, I am in the situation where I want to solve many cheap OT problems a lot of time 

Damn. More specifically, are you by chance planning to run KNN-like constructions on a metric space, where each experiment corresponds to one point cloud of samples and the metric (between experiments) is EMD, such that in the worst case you need quadratically many EMD solutions? That’s my guess because of the Machine Learning tag (learning distributions, not medium-dimensional points) combined with “many cheap” problems.
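To illustrate the quadratic blow-up, here is a small sketch that fills a pairwise distance matrix over N point clouds and counts the solver calls. As a stand-in for a real EMD solver it uses the 1D closed form for two equally weighted samples of the same size (mean absolute difference of the sorted values); `emd_1d` and `pairwise_emd` are hypothetical names, and real point clouds in higher dimensions would need a proper OT solver instead.

```python
from itertools import combinations

def emd_1d(xs, ys):
    """W1 distance between two equal-size, equally weighted 1D samples."""
    if len(xs) != len(ys):
        raise ValueError("this closed form assumes equal sample sizes")
    return sum(abs(x - y) for x, y in zip(sorted(xs), sorted(ys))) / len(xs)

def pairwise_emd(clouds):
    """Return the symmetric distance matrix and the number of EMD solves."""
    n = len(clouds)
    dist = [[0.0] * n for _ in range(n)]
    calls = 0
    for i, j in combinations(range(n), 2):
        d = emd_1d(clouds[i], clouds[j])
        dist[i][j] = dist[j][i] = d
        calls += 1
    return dist, calls

clouds = [[0.0, 1.0, 2.0], [1.0, 2.0, 3.0], [5.0, 5.0, 5.0]]
dist, calls = pairwise_emd(clouds)  # calls == N * (N - 1) / 2
```

With N experiments this is N(N-1)/2 solves, which is why a cheap per-problem cost matters more than the cost of any single large problem.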

I have met people running such computations; they were very, very unhappy with their compute needs (the context was some medical data that came in the form of a (sampled) distribution for each patient, so you have a partially labeled distribution of distributions).

I am pretty sure that there is a lot one could do in these settings, but this is algorithmic research rather than coding or googling for packages. It’s a fun problem, though, one that I have spent some spare time working on.

Edit: If you end up writing a paper-thin wrapper around POT using PyCall, it would be much appreciated if you could make a GitHub repo for it. That way you can also swap out POT for another solver if a faster one can be found.


#5

Just found this thread and wanted to chime in with my experience (hoping I am on topic, really):
That’s pretty much what I have: I am targeting EMD/Wasserstein distance for vector similarity (between sequences of word vectors, that is). As it’s quite expensive to compute in real time, I model it with a neural network: two sequence inputs, a siamese encoder (shared weights), and a custom target of which EMD is one component. This basically encodes such sequences as dense vector representations, which are then queried live (kNN over angular distances, Super-Bit LSH). To sum up, I pre-compute and cache the EMD values, then approximate them afterwards with the network. I am highly interested in Flux, to be able to look at this end to end.


#6

Ok, can you explain again, also giving orders of magnitude for sizes/dimensions?

E.g. your distributions (sentences, for the earth mover’s distance) contain 10-100 points each (e.g. the 10-1000 dimensional word2vec output for each word in a sentence, compared using cosine distance), and you want to use an approximate k-nearest-neighbor classifier based on the EMD. kNN classifiers do a lot of distance evaluations, and EMD is expensive. Therefore, you instead train a neural network NN_map to map sentence->some_Rk, such that EMD(sentence1, sentence2) ~ dist_cosine(NN_map(sentence1), NN_map(sentence2)). Did I get this right?
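As a sketch of what "EMD ~ dist_cosine after NN_map" could mean as a training objective: given embeddings that some encoder has already produced, penalise the gap between their cosine distance and the pre-computed EMD target. The encoder itself is not shown; `cosine_distance` and `surrogate_loss` are illustrative names, and the mean-squared-error form is an assumption, not the setup actually used above. Note that cosine distance lies in [0, 2], so raw EMD values would have to be rescaled into that range first.

```python
import math

def cosine_distance(u, v):
    """1 - cos(angle between u and v); assumes neither vector is all zeros."""
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return 1.0 - dot / (norm_u * norm_v)

def surrogate_loss(examples):
    """Mean squared gap between embedding distance and the EMD target.

    examples: list of (embedding_1, embedding_2, emd_target) triples, where
    the embeddings stand in for NN_map(sentence1) and NN_map(sentence2).
    """
    errors = [(cosine_distance(e1, e2) - target) ** 2
              for e1, e2, target in examples]
    return sum(errors) / len(errors)

# Orthogonal embeddings have cosine distance 1, so a target of 0.5 costs 0.25.
loss = surrogate_loss([([1.0, 0.0], [0.0, 1.0], 0.5)])
```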


#7

From this discussion, I found this package


#8

Precisely. In my case these sentences are sets of skills (“java + clojure + …”), and, to be complete (although not very relevant for this thread), in practice my custom target relies on several terms in its computation (e.g. one quantity is 1/(1+emd), another quantity is the divergence between LDA distributions,…). In total I have ~3M such encoded sentence vectors in 300d, with the vectors stored offline and kNN run afterwards.


#9

So, to clarify: Each sentence-vector has ~300 entries. Each entry is itself a weight and a point in a fixed metric space M. Then you would naively need to solve (or approximate) a lot (billions) of 300 x 300 EMD instances in M. M itself is a moderately small finite metric space (the set of possible “skills”, where you presumably learn the metric on M in a separate step). Luckily, M fulfills the triangle inequality, which may or may not help your EMD solver.
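Under that reading, one EMD instance is fully described by two weight vectors plus a cost matrix looked up from the ground metric on M. A minimal sketch of that data layout follows; the skill names, the distances in `metric`, and the function name `emd_instance` are all invented for illustration.

```python
def emd_instance(sentence_a, sentence_b, metric):
    """Build (weights_a, weights_b, cost) for one EMD problem.

    sentence_a, sentence_b : lists of (point_in_M, weight) pairs
    metric                 : metric[p][q] is the ground distance between p and q
    """
    weights_a = [w for _, w in sentence_a]
    weights_b = [w for _, w in sentence_b]
    # cost[i][j] = distance between the i-th point of a and the j-th point of b
    cost = [[metric[p][q] for q, _ in sentence_b] for p, _ in sentence_a]
    return weights_a, weights_b, cost

# Hypothetical ground metric on a three-element M (symmetric, zero diagonal).
metric = {
    "java":    {"java": 0.0, "clojure": 0.3, "sql": 0.8},
    "clojure": {"java": 0.3, "clojure": 0.0, "sql": 0.9},
    "sql":     {"java": 0.8, "clojure": 0.9, "sql": 0.0},
}
a = [("java", 0.5), ("sql", 0.5)]
b = [("clojure", 1.0)]
weights_a, weights_b, cost = emd_instance(a, b, metric)
```

Because M is finite and fixed, the full metric table can be computed once and every instance only needs a sub-matrix of it, so building the cost matrix is cheap compared to solving the transport problem.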

Because you jointly approximate EMD + other_target_components by Euclidean-after-NN_map (or, for cosine, Euclidean-after-normalization-after-NN_map), you are free to use kd-trees or similar for your kNN search, and don’t need to drop down to general metric-space search trees such as cover trees. (You already run an SVD in front of the search, right? A kd-tree is not isometry invariant and runs a lot better if you align the coordinate axes with the most relevant dimensions.)

Did I get this right?

Interesting, I hadn’t encountered this kind of problem before. I have mostly thought about infinite, more geometric metric spaces M that are known a priori: e.g. a large vector space where your sentences for EMD happen to be mostly supported on some lower-dimensional “data manifold” (not really a manifold), which you of course don’t know a priori. I’m originally coming from the perspective of trying to understand and treat the data manifold as a geometric object. As such, I unfortunately have nothing interesting to say about your problem, except that we should probably use very different approaches :wink: