How to compute 2D Wasserstein distance

Hello,

I would like to compute the Wasserstein distance between two discrete joint probability distributions. The package ExactOptimalTransport.jl does not appear to support distances between 2D distributions directly, but I wonder whether this can be accomplished in the discrete case by properly specifying the cost and distance function. In the one-dimensional case, we have:

using Distributions

using ExactOptimalTransport

dist1 = Categorical([.2,.3,.1,.4])

dist2 = Categorical([.3,.2,.4,.1])

wd = wasserstein(dist1, dist2)

In the 2D case for a 2×2 joint distribution, we have:

joint_dist1 = [.2 .3;.1 .4]

joint_dist2 = [.3 .2;.4 .1]

In my use case, I think the Chebyshev distance would work for the cost. In other words, moving mass from one cell to another incurs a cost of 1 in the 2×2 case. Can someone provide some guidance for accomplishing this?
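For concreteness, here is a sketch of the cost matrix I have in mind (plain Julia; the variable names are just for illustration), enumerating cells in column-major order as Julia stores matrices:

```julia
# Sketch of the intended cost: Chebyshev distance between cell indices.
# Cells are enumerated in column-major order, matching Julia's matrix storage.
cells = [(i, j) for j in 1:2 for i in 1:2]   # [(1,1), (2,1), (1,2), (2,2)]
C = [max(abs(a[1] - b[1]), abs(a[2] - b[2])) for a in cells, b in cells]
# In the 2×2 case every pair of distinct cells is one "king move" apart,
# so C is simply a 4×4 matrix of ones with a zero diagonal.
```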

Thank you.

Maybe you could use the function ot_cost with a custom cost matrix, similar to here:
Wasserstein non-negative matrix factorisation · OptimalTransport.jl (juliaoptimaltransport.github.io)

Thanks for your reply. It looks like ot_cost takes a function which accepts x and y coordinates, which I believe are indices in this case. To emulate the Chebyshev distance in the sample 2×2 case, the cost function returns 0 if the indices match and 1 if they do not. Unfortunately, this did not work: it should return .1 in the example below, but continues to return .30.

using Distributions
using ExactOptimalTransport
using ExactOptimalTransport: ot_cost

dist1 = Categorical([.3,.2,.4,.1])

dist2 = Categorical([.4,.2,.4,.0])

function dist_func(x, y)
    println("x $x y $y")
    return x == y ? 0 : 1
end

wasserstein(dist1, dist2)

ot_cost(dist_func, dist1, dist2)

Based on my understanding, the problem is that the distance is computed between adjacent indices sequentially. For example, to get the distance from 1 to 3, it computes the distance from 1 to 2, and then from 2 to 3. So I don’t think there is a way to compute the distance I need.
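If that understanding is right, the 1D solver is effectively integrating the cost along the CDFs (my assumption about the library's behavior, sketched here in plain Julia), which would explain the .30 it keeps returning:

```julia
# Sketch (assumption): for 1D distributions with a convex cost, W1 equals
# the area between the two CDFs, so a 0/1 cost function never gets a say
# for non-adjacent moves.
p = [0.3, 0.2, 0.4, 0.1]
q = [0.4, 0.2, 0.4, 0.0]
w1 = sum(abs.(cumsum(p) .- cumsum(q)))   # ≈ 0.3, matching the returned value
```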

There should be better ways, but to get things done, in case of 2x2 matrices, and if I understand the problem correctly, the following might be a crude way to calculate the distance:

function otdist2x2(a_in, A_in)
    a = copy(a_in)
    A = copy(A_in)
    ot_dist = 0.0
    for i in 1:4
        keep = min(a[i],A[i])
        println("$i -> $i $keep")
        a[i] -= keep
        A[i] -= keep
    end
    for (i,j) in [(2,4),(1,2),(3,1),(4,2),(1,3),(2,1),(3,4),(4,3)]
        move = min(a[i],A[j])
        println("$i -> $j $move")
        a[i] -= move
        A[j] -= move
        ot_dist += move
    end
    for (i,j) in [(1,4),(4,1),(3,2),(2,3)]
        move = min(a[i], A[j])
        println("$i -> $j $move")
        a[i] -= move
        A[j] -= move
        ot_dist += 2*move
    end
    ot_dist
end

This code prints 'debugging' info to clarify its operation; the println statements should be removed for real use. Some examples:

julia> a = [1.0 0.0 ; 0.0 1.0]
2×2 Matrix{Float64}:
 1.0  0.0
 0.0  1.0

julia> A = [0.0 1.0 ; 1.0 0.0]
2×2 Matrix{Float64}:
 0.0  1.0
 1.0  0.0

julia> otdist2x2(a,A) # two very far matrices
...
2.0

julia> otdist2x2(a,a) # matrix is zero distance from itself
...
0.0

julia> u = [0.25 0.25; 0.25 0.25]
2×2 Matrix{Float64}:
 0.25  0.25
 0.25  0.25

julia> otdist2x2(a,u) # distance to uniform distribution
...
0.5

julia> otdist2x2(A,u) # should be the same
...
0.5

Is this the value you are looking for?
If yes, the code can be streamlined.

Thank you for your reply. I think this is close based on some test examples. The following examples returned the expected value of .10:

p1 = [.3 .2;.4 .1]

p2 = [.3 .3;.4 .0]

otdist2x2(p1, p2)

p1 = [.3 .2;.4 .1]

p2 = [.3 .2;.5 .0]

otdist2x2(p1, p2)

One example that was unexpected was the following:

p1 = [1 0;.0 0]

p2 = [0 0; 0 1]

otdist2x2(p1, p2)

which returned 2, but I was expecting 1. Is there a bug, or did I make a reasoning error?

Ok, I think I understand now. It looks like you are using the Manhattan distance, but what I am looking for is the Chebyshev distance, which means that going from indices [1,1] to [2,2] costs one (i.e., one diagonal move).

This is easily fixed, by just changing:

           ot_dist += 2*move

at the bottom of the function to:

           ot_dist += move

Now that the required result is established, I think it will be possible to recode this in a more elegant way.


Thank you very much. I will continue testing, but it looks promising!

A new way is quite short:

otdist2x2(a_in, A_in) = sum(max.(0.0, a_in .- A_in))

Does this work?

I changed it to otdist2x2(a_in, A_in) = 1.0 - sum(max.(0.0, a_in .- A_in)) and so far it works. Thank you! I will report back if I find a problem with either implementation.

Also, it might be nice to have a general solution for discrete joint distributions. I might open an issue on ExactOptimalTransport.jl for this reason.

Try the new version (I edited the post a little within the 5-minute grace period, so the edit doesn't show). I think the 1.0 - sum(..) isn't good (the distance of a distribution from itself is a counterexample).
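To see why, a quick sketch (the name `bad_dist` is just mine):

```julia
# Counterexample: a distance must satisfy d(p, p) == 0, but the
# 1.0 - sum(...) variant returns 1.0 for identical distributions.
bad_dist(a, A) = 1.0 - sum(max.(0.0, a .- A))
p = [0.25 0.25; 0.25 0.25]
bad_dist(p, p)   # returns 1.0 instead of the expected 0.0
```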

And considering the simplicity of the logic, it should work for any discrete distribution, as long as the metric is the discrete topology metric (i.e. 0/1, same/different)
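A sketch of that generalization (my naming; under the 0/1 metric this is just the total variation distance):

```julia
# General 0/1-metric OT cost for any two discrete distributions on the
# same support: overlapping mass stays in place for free, everything
# else moves at cost 1. This equals the total variation distance.
otdist_discrete(p, q) = sum(max.(0.0, p .- q))

otdist_discrete([1.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 1.0])  # 1.0
otdist_discrete([0.2, 0.3, 0.5], [0.5, 0.3, 0.2])        # ≈ 0.3
```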


Thank you again! I posted some of my test cases below for future reference.

using Test 

function otdist2x2(a_in, A_in)
    a = copy(a_in)
    A = copy(A_in)
    ot_dist = 0.0
    for i in 1:4
        keep = min(a[i],A[i])
        a[i] -= keep
        A[i] -= keep
    end
    for (i,j) in [(2,4),(1,2),(3,1),(4,2),(1,3),(2,1),(3,4),(4,3)]
        move = min(a[i],A[j])
        a[i] -= move
        A[j] -= move
        ot_dist += move
    end
    for (i,j) in [(1,4),(4,1),(3,2),(2,3)]
        move = min(a[i], A[j])
        a[i] -= move
        A[j] -= move
        ot_dist += move
    end
    ot_dist
end

otdist2x2_elegant(a_in, A_in) = sum(max.(0.0, a_in .- A_in))

expected1 = .1
p1 = [.3 .2;.4 .1]
p2 = [.3 .3;.4 .0]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected1 ≈ d1 
@test expected1 ≈ d2

expected2 = .1
p1 = [.3 .2;.4 .1]
p2 = [.3 .2;.5 .0]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected2 ≈ d1 
@test expected2 ≈ d2

expected3 = .30
p1 = [.3 .3;.4 .0]
p2 = [.2 .2; .3 .3]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected3 ≈ d1 
@test expected3 ≈ d2

expected4 = 1
p1 = [1 0;.0 0]
p2 = [0 0; 0 1]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected4 ≈ d1 
@test expected4 ≈ d2

expected5 = .75
p1 = [.25 .25;.25 .25]
p2 = [0 0; 0 1.]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected5 ≈ d1 
@test expected5 ≈ d2

expected6 = .75
p1 = [.50 0.0;.25 .25]
p2 = [0 .5; .0 .5]
d1 = otdist2x2(p1, p2)
d2 = otdist2x2_elegant(p1, p2)
@test expected6 ≈ d1 
@test expected6 ≈ d2