Computing discrete Wasserstein (EMD) distance in Julia?



Is there any library in Julia for computing the discrete Wasserstein (EMD) distance between two discrete distributions?
Python seems to have a tool for this, but a pure Julia solution, if one is available, would be better than going through PyCall.
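For reference, here is what the Python side can look like in the special 1D case, where SciPy already ships a closed-form discrete Wasserstein-1 implementation (no LP solver needed); the general multi-dimensional EMD is what needs a library like POT:

```python
import numpy as np
from scipy.stats import wasserstein_distance

# Two discrete distributions: support points with associated weights.
u_values = np.array([0.0, 1.0, 3.0])
u_weights = np.array([0.5, 0.25, 0.25])
v_values = np.array([1.0, 2.0])
v_weights = np.array([0.5, 0.5])

# In 1D, Wasserstein-1 (EMD) has a closed form via the inverse CDFs,
# so no linear program has to be solved here.
d = wasserstein_distance(u_values, v_values, u_weights, v_weights)
print(d)  # → 1.0
```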


Thanks for the link, that looks cool! Unfortunately, I don’t know of a Julia package for optimal transport.

From a (very) quick glance at the source, it looks like POT does most of the numerical heavy lifting in C/C++ libraries, but the Python code is also non-trivial.

If this is correct, then PyCall looks like a good choice: there is no big speed-up to be had if POT is already calling fast C code, and calling the C libraries from Julia directly would be non-trivial because the Python layer is not just a thin wrapper but contains real logic. Of course, if you want to experiment with new algorithms for optimal transport, then this state of affairs is unfortunate (the two-language problem), and it also costs you heavy dependencies.
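Just to make concrete what kind of problem that C code is solving: the exact discrete EMD is a small linear program over the transport plan. A self-contained sketch using only SciPy's generic LP solver (fine for tiny instances; a dedicated network-flow solver like POT's will be much faster):

```python
import numpy as np
from scipy.optimize import linprog

def emd(a, b, M):
    """Exact EMD between histograms a (n,) and b (m,) with cost matrix
    M (n, m), solved as an LP over the transport plan T >= 0 whose
    row sums equal a and column sums equal b."""
    n, m = M.shape
    A_eq = np.zeros((n + m, n * m))
    for i in range(n):
        A_eq[i, i * m:(i + 1) * m] = 1.0   # sum_j T[i, j] = a[i]
    for j in range(m):
        A_eq[n + j, j::m] = 1.0            # sum_i T[i, j] = b[j]
    b_eq = np.concatenate([a, b])
    res = linprog(M.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
    return res.fun

# Example: move all mass one step to the right on {0, 1, 2}.
a = np.array([1.0, 0.0, 0.0])
b = np.array([0.0, 1.0, 0.0])
M = np.abs(np.subtract.outer(np.arange(3.0), np.arange(3.0)))
print(emd(a, b, M))  # → 1.0
```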

Is this assessment correct? You probably know more about POT than I do.

I have never used POT, but they do offer a buffet of solvers, have spent the effort to offload a lot of computation into faster languages, and choices like Sinkhorn vs. a direct solver are probably more relevant than the choice of programming language anyway (algorithm class very often beats implementation/hardware). I'm not sure how fast they are compared to other existing libraries.
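To make the Sinkhorn-vs-direct distinction concrete: entropic regularization replaces the exact LP with simple matrix-scaling iterations, which is why it scales to settings where an exact solver is too slow, at the price of a slightly blurred plan. A minimal numpy sketch (the parameter names `reg` and `n_iter` are my own, not POT's API):

```python
import numpy as np

def sinkhorn(a, b, M, reg=0.05, n_iter=1000):
    """Entropically regularized OT cost via Sinkhorn matrix scaling."""
    K = np.exp(-M / reg)             # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)            # scale columns to match marginal b
        u = a / (K @ v)              # scale rows to match marginal a
    T = u[:, None] * K * v[None, :]  # recovered transport plan
    return float(np.sum(T * M))

a = np.array([0.8, 0.2])
b = np.array([0.2, 0.8])
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])
print(sinkhorn(a, b, M))  # ≈ 0.6, the exact EMD here; entropic bias is tiny
```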

Or are you in the situation where you want to solve many cheap OT problems instead of a few expensive ones? Then you are in a bad situation, I fear :frowning:


I am also new to POT. I had a look at their code and made observations similar to yours. The Python portions of their code are not trivial; I am not sure whether that will cause performance problems.

Yeah, I am in the situation where I want to solve many cheap OT problems, over and over :frowning:


Yeah, I am in the situation where I want to solve many cheap OT problems, over and over 

Merde. More specifically: are you by any chance planning to run KNN-like constructions on a metric space where each experiment corresponds to one point cloud of samples, and the metric (between experiments) is EMD, such that you need, in the worst case, quadratically many EMD solutions? That's my guess because of the machine-learning tag (learning distributions, not medium-dimensional points) combined with "many cheap" problems.
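To spell out why that blows up: a KNN-style construction needs the full pairwise distance matrix, i.e. O(n²) EMD solves, each of which is itself an optimization problem. A toy sketch with 1D sample clouds (using SciPy's 1D Wasserstein purely to show the quadratic loop structure; higher-dimensional clouds would need a full EMD solver at each step):

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
# One "experiment" = one point cloud of samples (1D here for simplicity).
experiments = [rng.normal(loc=i, size=50) for i in range(4)]

n = len(experiments)
D = np.zeros((n, n))
for i in range(n):
    for j in range(i + 1, n):  # n*(n-1)/2 EMD solves in total
        D[i, j] = D[j, i] = wasserstein_distance(experiments[i], experiments[j])

print(np.round(D, 2))
```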

I have met people running such computations; they were very, very unhappy with their compute needs (the context was medical data that came in the form of a (sampled) distribution for each patient, so you have a partially labeled distribution-of-distributions).

I am pretty sure there is a lot one could do in these settings, but that is algorithmic research rather than coding or googling for packages. It's a fun problem, though, and one I have spent some off-time working on.

Edit: If you end up writing a paper-thin PyCall wrapper around POT, it would be much appreciated if you could put it in a GitHub repo. That way you could also swap out POT for another solver if a faster one turns up.