POT generating incorrect result for very simple OT problem #345

Closed
@jstac

Apologies if I'm doing something stupid, but I don't think I am. The simple example

import numpy as np
import ot

phi = np.array((0.5, 0.5))   # distribution 1
psi = np.array((0.5, 0.5))   # distribution 2
c = ((2, 1),
     (1, 1))
c = np.array(c)

pi = ot.emd(phi, psi, c)

produces the incorrect result

array([[0, 0],
       [0, 0]])

(Clearly the optimal plan sends all mass at point 1 to point 2 and all mass at point 2 to point 1, for a total cost of 1.)
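This can be checked by hand. With both marginals equal to (0.5, 0.5), every feasible 2x2 plan has the form [[t, 0.5 - t], [0.5 - t, t]] for t in [0, 0.5], and its cost under `c` is 1 + t, which is smallest at t = 0. A minimal sketch using only NumPy; the helper name `plan` and the grid of `t` values are introduced here for illustration:

```python
import numpy as np

c = np.array([[2, 1],
              [1, 1]])

def plan(t):
    """Feasible plan with both marginals (0.5, 0.5), indexed by t in [0, 0.5]."""
    return np.array([[t, 0.5 - t],
                     [0.5 - t, t]])

# Evaluate the transport cost on a small grid of feasible plans
ts = np.linspace(0.0, 0.5, 6)
costs = np.array([np.sum(c * plan(t)) for t in ts])

# Pick the cheapest plan on the grid
best_t = ts[np.argmin(costs)]
best_plan = plan(best_t)
best_cost = costs.min()
print(best_plan, best_cost)
```

Because the cost is linear in `t`, the grid minimum at the endpoint t = 0 is also the true minimum over the whole interval.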

Direct application of linear programming produces the correct result

array([[ 0. ,  0.5],
       [ 0.5, -0. ]])
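One visible difference between the two outputs is that the `ot.emd` result prints as an integer array while the linear programming result is floating point. A possible explanation, offered here as an unconfirmed assumption about POT 0.8.0, is that the returned plan is cast to the dtype of the integer cost matrix `c`, which would truncate each 0.5 to 0. The sketch below uses only NumPy to show that truncation and does not call POT. Passing a float cost matrix to `ot.emd`, for example `c.astype(float)`, is a workaround to try rather than a confirmed fix.

```python
import numpy as np

c = np.array([[2, 1],
              [1, 1]])                 # integer dtype, as in the example above
expected = np.array([[0.0, 0.5],
                     [0.5, 0.0]])      # plan found by linear programming

# Casting the float plan to the integer dtype of c truncates every 0.5 to 0
truncated = expected.astype(c.dtype)
print(truncated)

# A float copy of the cost matrix has a dtype that keeps the entries intact
c_float = c.astype(np.float64)
preserved = expected.astype(c_float.dtype)
print(preserved)
```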

Here's the direct linear programming code

# linprog lives in scipy.optimize
from scipy.optimize import linprog

# Define parameters
m = n = 2

# Vectorize matrix C in column-major order
c_vec = c.reshape((m * n, 1), order='F')

# Construct matrix A by Kronecker product
A1 = np.kron(np.ones((1, n)), np.identity(m))   # row-sum constraints
A2 = np.kron(np.identity(n), np.ones((1, m)))   # column-sum constraints
A  = np.vstack([A1, A2])

# Construct vector b
b = np.hstack([phi, psi])

# Solve the primal problem
res = linprog(c_vec, A_eq=A, b_eq=b, method='highs-ipm')

# Print results
pi = res.x.reshape((m, n), order='F')
print(pi)
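To see what the Kronecker construction encodes, the following sketch (NumPy only, no solver) checks that the plan reported above satisfies the equality constraints and has total cost 1. The vector `x` holds the plan in column-major order, matching `order='F'`; the names `pi_expected`, `constraints_ok` and `total_cost` are introduced here for illustration.

```python
import numpy as np

m = n = 2
phi = np.array((0.5, 0.5))
psi = np.array((0.5, 0.5))
c = np.array(((2, 1), (1, 1)))

A1 = np.kron(np.ones((1, n)), np.identity(m))   # each row of A1 sums one row of the plan
A2 = np.kron(np.identity(n), np.ones((1, m)))   # each row of A2 sums one column of the plan
A = np.vstack([A1, A2])
b = np.hstack([phi, psi])

# Plan from the linear programming result, flattened column by column
pi_expected = np.array([[0.0, 0.5], [0.5, 0.0]])
x = pi_expected.reshape(m * n, order='F')

# Check the marginal constraints and evaluate the objective
constraints_ok = np.allclose(A @ x, b)
total_cost = float(c.reshape(m * n, order='F') @ x)
print(constraints_ok, total_cost)
```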

Environment (please complete the following information):

Manjaro linux, POT installed via pip in Anaconda environment.

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Linux-5.13.19-2-MANJARO-x86_64-with-glibc2.17
Python 3.8.12 (default, Oct 12 2021, 13:49:34)
[GCC 7.5.0]
NumPy 1.20.3
SciPy 1.7.1
POT 0.8.0
