# How to compute 2D Wasserstein distance

Hello,

I would like to compute the Wasserstein distance between two discrete joint probability distributions. The package ExactOptimalTransport.jl does not appear to support a 2D distance, but I wonder whether it can be accomplished in the discrete case by properly specifying the cost function. In the one-dimensional case, we have:

``````julia
using Distributions
using ExactOptimalTransport

dist1 = Categorical([.2,.3,.1,.4])
dist2 = Categorical([.3,.2,.4,.1])

wd = wasserstein(dist1, dist2)
``````

In the 2D case for a 2×2 joint distribution, we have:

``````julia
joint_dist1 = [.2 .3;.1 .4]
joint_dist2 = [.3 .2;.4 .1]
``````

In my use case, I think Chebyshev distance would work for the cost. In other words, moving mass from one cell to another incurs a cost of 1 in the 2×2 case. Can someone provide some guidance for accomplishing this?

Thank you.

Maybe you could use the function `ot_cost` with a custom cost matrix similar to here:
Wasserstein non-negative matrix factorisation · OptimalTransport.jl (juliaoptimaltransport.github.io)

Thanks for your reply. It looks like `ot_cost` takes a function which accepts x and y coordinates, which I believe are indices in this case. To emulate the Chebyshev distance in the sample 2×2 case, the cost function returns 0 if the indices match and 1 if they do not. Unfortunately, this did not work. It should return .1 in the example below, but it still returns .30.

``````julia
using Distributions
using ExactOptimalTransport
using ExactOptimalTransport: ot_cost

dist1 = Categorical([.3,.2,.4,.1])
dist2 = Categorical([.4,.2,.4,.0])

# 0/1 cost: 0 if the indices match, 1 otherwise
function dist_func(x, y)
    println("x $x y $y")
    return x == y ? 0 : 1
end

wasserstein(dist1, dist2)

ot_cost(dist_func, dist1, dist2)
``````

Based on my understanding, the problem is that the distance is computed between adjacent indices sequentially. For example, to get the distance from 1 to 3, it computes the distance from 1 to 2, and then from 2 to 3. So I don’t think there is a way to compute the distance I need.
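For reference, the cost described in the question can be written out explicitly as a matrix over grid cells. The sketch below builds the Chebyshev cost between every pair of cells of a grid, with cells flattened in Julia's column-major order. `chebyshev_cost` is a hypothetical helper name, not part of ExactOptimalTransport.jl. Whether and how a solver in that package accepts such a matrix is an assumption to check against its documentation.

```julia
# Chebyshev cost between all pairs of cells of an nrows×ncols grid.
# Row/column k of the result corresponds to linear (column-major) index k.
function chebyshev_cost(nrows, ncols)
    cells = vec(collect(CartesianIndices((nrows, ncols))))
    return [maximum(abs.(Tuple(p) .- Tuple(q))) for p in cells, q in cells]
end

C = chebyshev_cost(2, 2)   # 0 on the diagonal, 1 everywhere else
```

On a 2×2 grid every pair of distinct cells is one (possibly diagonal) step apart, which is why the cost collapses to 0/1 there. On a larger grid the entries grow with the distance between cells.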

There should be better ways, but to get things done, in case of 2x2 matrices, and if I understand the problem correctly, the following might be a crude way to calculate the distance:

``````julia
function otdist2x2(a_in, A_in)
    a = copy(a_in)
    A = copy(A_in)
    ot_dist = 0.0
    # Mass that is already in the right cell stays put at no cost
    for i in 1:4
        keep = min(a[i],A[i])
        println("$i -> $i $keep")
        a[i] -= keep
        A[i] -= keep
    end
    # Moves between horizontally or vertically adjacent cells (cost 1)
    for (i,j) in [(2,4),(1,2),(3,1),(4,2),(1,3),(2,1),(3,4),(4,3)]
        move = min(a[i],A[j])
        println("$i -> $j $move")
        a[i] -= move
        A[j] -= move
        ot_dist += move
    end
    # Moves between diagonally opposite cells (cost 2)
    for (i,j) in [(1,4),(4,1),(3,2),(2,3)]
        move = min(a[i], A[j])
        println("$i -> $j $move")
        a[i] -= move
        A[j] -= move
        ot_dist += 2*move
    end
    ot_dist
end
``````

This code prints ‘debugging’ info to clarify operation. The `println` statements should be removed. Some examples:

``````julia
julia> a = [1.0 0.0 ; 0.0 1.0]
2×2 Matrix{Float64}:
 1.0  0.0
 0.0  1.0

julia> A = [0.0 1.0 ; 1.0 0.0]
2×2 Matrix{Float64}:
 0.0  1.0
 1.0  0.0

julia> otdist2x2(a,A) # two very far matrices
...
2.0

julia> otdist2x2(a,a) # matrix is zero distance from itself
...
0.0

julia> u = [0.25 0.25; 0.25 0.25]
2×2 Matrix{Float64}:
 0.25  0.25
 0.25  0.25

julia> otdist2x2(a,u) # distance to uniform distribution
...
0.5

julia> otdist2x2(A,u) # should be the same
...
0.5
``````

Is this the value you are looking for?
If yes, the code can be streamlined.

Thank you for your reply. I think this is close based on some test examples. The following examples returned the expected value of .10:

``````julia
p1 = [.3 .2;.4 .1]
p2 = [.3 .3;.4 .0]
otdist2x2(p1, p2)

p1 = [.3 .2;.4 .1]
p2 = [.3 .2;.5 .0]
otdist2x2(p1, p2)
``````

One example that was unexpected was the following:

``````julia
p1 = [1 0;.0 0]
p2 = [0 0; 0 1]
otdist2x2(p1, p2)
``````

which returned 2, but I was expecting 1. Is there a bug, or did I make a reasoning error?

Ok. I think I understand now. It looks like you are using Manhattan distance, but what I am looking for is Chebyshev distance, which means if I go from indices [1,1] to [2,2], the distance is one (e.g., one move diagonally).
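To make the difference concrete, here is a minimal sketch of the two metrics applied to the move from cell [1,1] to cell [2,2]. The names `manhattan` and `chebyshev` are illustrative helpers defined here, not functions from any package used in this thread.

```julia
# Manhattan distance: sum of the per-coordinate steps
manhattan(p, q) = sum(abs.(p .- q))

# Chebyshev distance: largest per-coordinate step (a diagonal move counts once)
chebyshev(p, q) = maximum(abs.(p .- q))

d_manhattan = manhattan((1, 1), (2, 2))   # 2
d_chebyshev = chebyshev((1, 1), (2, 2))   # 1
```

This matches the observed result: the function returned 2 because the diagonal move was charged twice.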

This is easily fixed by changing:

``````julia
ot_dist += 2*move
``````

at the bottom of the function to:

``````julia
ot_dist += move
``````

Now that the required result is established, I think it will be possible to recode this in a more elegant way.


Thank you very much. I will continue testing, but it looks promising!

A new way is quite short:

``````julia
otdist2x2(a_in, A_in) = sum(max.(0.0, a_in .- A_in))
``````

Does this work?

I changed it to `otdist2x2(a_in, A_in) = 1.0 - sum(max.(0.0, a_in .- A_in))` and so far it works. Thank you! I will report back if I find a problem with either implementation.

Also, it might be nice to have a general solution for discrete joint distributions. I might open an issue on ExactOptimalTransport.jl for this reason.

Try the new version (I edited the post within the 5-minute grace period, so the edit doesn't show). I think the `1.0 - sum(..)` version isn't good: the distance of a distribution from itself should be 0, but that version returns 1.
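A quick way to check this is to evaluate both candidates on identical inputs. This is only a sanity-check sketch; `positive_part` and `one_minus` are hypothetical names for the two one-liners discussed above.

```julia
# The one-liner from the edited post
positive_part(a, b) = sum(max.(0.0, a .- b))

# The `1.0 - sum(..)` variant
one_minus(a, b) = 1.0 - sum(max.(0.0, a .- b))

p = [0.3 0.2; 0.4 0.1]

self_positive = positive_part(p, p)   # 0.0, as a distance should be
self_one_minus = one_minus(p, p)      # 1.0, so this is not a distance
```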

And considering the simplicity of the logic, it should work for any discrete distribution, as long as the metric is the discrete metric (i.e. 0/1, same/different).
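As a sketch of that generalisation: under the 0/1 metric the transport cost is the total variation distance, which can also be written as half the L1 distance. The code below assumes both arrays have the same shape, are non-negative, and have the same total mass. `discrete_ot` and `half_l1` are hypothetical names.

```julia
# Transport cost under the 0/1 (same cell / different cell) metric:
# only the surplus mass in each cell has to move, and every move costs 1.
discrete_ot(p, q) = sum(max.(0.0, p .- q))

# Equivalent form when p and q have equal total mass: half the L1 distance.
half_l1(p, q) = sum(abs.(p .- q)) / 2

p = [0.1 0.2 0.1; 0.3 0.1 0.2]   # 2×3 joint distribution
q = [0.2 0.1 0.1; 0.1 0.3 0.2]

d_ot = discrete_ot(p, q)
d_l1 = half_l1(p, q)
```

Note that on grids larger than 2×2 the Chebyshev distance between distant cells exceeds 1, so this shortcut no longer computes a Chebyshev-based Wasserstein distance there. It only computes the cost under the same/different metric.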


Thank you again! I posted some of my test cases below for future reference.

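``````julia
using Test

# Greedy version, with every move (adjacent or diagonal) costing 1
function otdist2x2(a_in, A_in)
    a = copy(a_in)
    A = copy(A_in)
    ot_dist = 0.0
    # Mass that is already in the right cell stays put at no cost
    for i in 1:4
        keep = min(a[i],A[i])
        a[i] -= keep
        A[i] -= keep
    end
    # Moves between horizontally or vertically adjacent cells
    for (i,j) in [(2,4),(1,2),(3,1),(4,2),(1,3),(2,1),(3,4),(4,3)]
        move = min(a[i],A[j])
        a[i] -= move
        A[j] -= move
        ot_dist += move
    end
    # Moves between diagonally opposite cells
    for (i,j) in [(1,4),(4,1),(3,2),(2,3)]
        move = min(a[i], A[j])
        a[i] -= move
        A[j] -= move
        ot_dist += move
    end
    ot_dist
end

# One-line version; named separately so that both can be tested side by side
otdist2x2_elegant(a_in, A_in) = sum(max.(0.0, a_in .- A_in))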

expected1 = .1
p1 = [.3 .2;.4 .1]
p2 = [.3 .3;.4 .0]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected1 ≈ d1
@test expected1 ≈ d2

expected2 = .1
p1 = [.3 .2;.4 .1]
p2 = [.3 .2;.5 .0]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected2 ≈ d1
@test expected2 ≈ d2

expected3 = .30
p1 = [.3 .3;.4 .0]
p2 = [.2 .2; .3 .3]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected3 ≈ d1
@test expected3 ≈ d2

expected4 = 1
p1 = [1 0;.0 0]
p2 = [0 0; 0 1]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected4 ≈ d1
@test expected4 ≈ d2

expected5 = .75
p1 = [.25 .25;.25 .25]
p2 = [0 0; 0 1.]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected5 ≈ d1
@test expected5 ≈ d2

expected6 = .75
p1 = [.50 0.0;.25 .25]
p2 = [0 .5; .0 .5]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected6 ≈ d1
@test expected6 ≈ d2
``````