Description
Apologies if I'm doing something stupid, though I don't think I am. The following simple example
```python
import numpy as np
import ot

phi = np.array((0.5, 0.5))  # distribution 1
psi = np.array((0.5, 0.5))  # distribution 2
c = ((2, 1),
     (1, 1))
c = np.array(c)
pi = ot.emd(phi, psi, c)
```
produces an incorrect result, an all-zero transport plan:
```
array([[0, 0],
       [0, 0]])
```
(Clearly we should send all mass at 1 to 2 and all mass at 2 to 1.)
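The claim can be checked by comparing the total cost ⟨π, c⟩ = Σ π_ij c_ij of the two extreme plans that satisfy both marginals: keeping each point's mass in place versus swapping it. (The all-zero matrix is not a feasible plan at all, since its rows and columns sum to 0 rather than 0.5.) A small NumPy sketch:

```python
import numpy as np

# Cost matrix from the example above, written with float entries
c = np.array([[2.0, 1.0],
              [1.0, 1.0]])

# Plan that keeps each point's mass in place (diagonal)
pi_stay = np.array([[0.5, 0.0],
                    [0.0, 0.5]])
# Plan that sends the mass at 1 to 2 and the mass at 2 to 1 (off-diagonal)
pi_swap = np.array([[0.0, 0.5],
                    [0.5, 0.0]])

# Total transport cost <pi, c> = sum_ij pi_ij * c_ij
cost_stay = float(np.sum(pi_stay * c))  # 0.5*2 + 0.5*1 = 1.5
cost_swap = float(np.sum(pi_swap * c))  # 0.5*1 + 0.5*1 = 1.0
print(cost_stay, cost_swap)
```

More generally, every feasible plan here has the form [[t, 0.5 − t], [0.5 − t, t]] with 0 ≤ t ≤ 0.5, and its cost is 2t + 2(0.5 − t) + t = 1 + t, so the swap (t = 0) is the unique optimum.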
Solving the same problem directly as a linear program produces the correct result:
```
array([[ 0. ,  0.5],
       [ 0.5, -0. ]])
```
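One possible explanation, offered as an assumption rather than a confirmed diagnosis: `c` is built from Python integers, so it has an integer dtype, and if `ot.emd` casts the plan it returns to the dtype of the cost matrix, the optimal entries 0.5 are truncated to 0. The sketch below reproduces only that hypothesized casting step with NumPy; it does not call POT.

```python
import numpy as np

# Cost matrix built exactly as in the report: integer entries give an integer dtype
c = np.array(((2, 1),
              (1, 1)))
print(c.dtype.kind)  # prints: i (signed integer)

# The correct plan, as found by the linear program
pi_float = np.array([[0.0, 0.5],
                     [0.5, 0.0]])

# Assumption: the solver's float result is cast to the cost matrix's dtype.
# Casting float -> int truncates toward zero, so every 0.5 becomes 0.
pi_cast = pi_float.astype(c.dtype)
print(pi_cast)

# With a float cost matrix the same cast leaves the plan unchanged
pi_cast_float = pi_float.astype(c.astype(float).dtype)
print(pi_cast_float)
```

If this is the cause, passing a float cost matrix, e.g. `ot.emd(phi, psi, c.astype(float))`, should return the expected plan.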
Here is the direct linear-programming code:
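```python
import numpy as np
from scipy.optimize import linprog

# Define parameters (phi, psi and c are as in the example above)
m = n = 2
# Vectorize matrix C column by column (1-D, as linprog expects)
c_vec = c.reshape(m * n, order='F')
# Construct matrix A by Kronecker product
A1 = np.kron(np.ones((1, n)), np.identity(m))  # row-sum (phi) constraints
A2 = np.kron(np.identity(n), np.ones((1, m)))  # column-sum (psi) constraints
A = np.vstack([A1, A2])
# Construct vector b
b = np.hstack([phi, psi])
# Solve the primal problem (linprog's default bounds already enforce pi >= 0)
res = linprog(c_vec, A_eq=A, b_eq=b, method='highs-ipm')
# Reshape the solution back into an m x n plan and print it
pi = res.x.reshape((m, n), order='F')
print(pi)
```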
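Why the Kronecker construction encodes the marginal constraints: with column-major (`order='F'`) vectorization, `A1 @ vec(pi)` sums each row of `pi` and `A2 @ vec(pi)` sums each column, so `A @ vec(pi) = b` says exactly that the row sums equal `phi` and the column sums equal `psi`. The NumPy sketch below checks this on a hypothetical non-symmetric test matrix `X`, then checks feasibility of the swap plan and of the all-zero plan.

```python
import numpy as np

m = n = 2
phi = np.array([0.5, 0.5])
psi = np.array([0.5, 0.5])

# Same construction as in the LP code
A1 = np.kron(np.ones((1, n)), np.identity(m))  # (m, m*n): picks out row sums
A2 = np.kron(np.identity(n), np.ones((1, m)))  # (n, m*n): picks out column sums
A = np.vstack([A1, A2])
b = np.hstack([phi, psi])

# Hypothetical non-symmetric test matrix, vectorized column by column
X = np.array([[1.0, 2.0],
              [3.0, 4.0]])
x_vec = X.reshape(m * n, order='F')  # [1, 3, 2, 4]
print(A1 @ x_vec, X.sum(axis=1))  # [3. 7.] [3. 7.]
print(A2 @ x_vec, X.sum(axis=0))  # [4. 6.] [4. 6.]

# The swap plan satisfies both marginals; the all-zero plan does not
pi_swap = np.array([[0.0, 0.5],
                    [0.5, 0.0]])
pi_zero = np.zeros((m, n))
print(np.allclose(A @ pi_swap.reshape(m * n, order='F'), b))  # True
print(np.allclose(A @ pi_zero.reshape(m * n, order='F'), b))  # False
```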
Environment
Manjaro Linux; POT installed via pip in an Anaconda environment.
Output of the following code snippet:
```python
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
```
```
Linux-5.13.19-2-MANJARO-x86_64-with-glibc2.17
Python 3.8.12 (default, Oct 12 2021, 13:49:34)
[GCC 7.5.0]
NumPy 1.20.3
SciPy 1.7.1
POT 0.8.0
```