Skip to content

[MRG] Fix bug in emd2 with empty weights on backends #606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Releases

## 0.9.3
## 0.9.3dev

#### New features
+ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specify if the matrices are symmetric, in which case the computation can be done faster.
Expand All @@ -9,6 +9,7 @@
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced in PR #587 (PR #593)
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)
- Fix doc and example for lowrank sinkhorn (PR #601)
- Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534)

## 0.9.2
*December 2023*
Expand Down
2 changes: 1 addition & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
# utils functions
from .utils import dist, unif, tic, toc, toq

__version__ = "0.9.3"
__version__ = "0.9.3dev"

__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
Expand Down
2 changes: 0 additions & 2 deletions ot/gromov/_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,6 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,

"""
if nx is None:
G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)

if isinstance(M, int) or isinstance(M, float):
nx = get_backend(G, deltaG, C1, C2)
else:
Expand Down
2 changes: 0 additions & 2 deletions ot/gromov/_semirelaxed.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,6 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
Gromov-Wasserstein". NeurIPS 2023 Workshop OTML.
"""
if nx is None:
G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)

if isinstance(M, int) or isinstance(M, float):
nx = get_backend(G, deltaG, C1, C2)
else:
Expand Down
50 changes: 30 additions & 20 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,17 +302,24 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c
ot.optim.cg : General regularized OT
"""

# convert to numpy if list
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)

a0, b0, M0 = a, b, M
if len(a0) != 0:
type_as = a0
elif len(b0) != 0:
type_as = b0
if len(a) != 0:
type_as = a
elif len(b) != 0:
type_as = b
else:
type_as = M0
nx = get_backend(M0, a0, b0)
type_as = M

# if empty array given then use uniform distributions
if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
if len(b) == 0:
b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]

# store original tensors
a0, b0, M0 = a, b, M

# convert to numpy
M, a, b = nx.to_numpy(M, a, b)
Expand Down Expand Up @@ -474,15 +481,23 @@ def emd2(a, b, M, processes=1,
"""

a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)

a0, b0, M0 = a, b, M
if len(a0) != 0:
type_as = a0
elif len(b0) != 0:
type_as = b0
if len(a) != 0:
type_as = a
elif len(b) != 0:
type_as = b
else:
type_as = M0
nx = get_backend(M0, a0, b0)
type_as = M

# if empty array given then use uniform distributions
if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
if len(b) == 0:
b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]

# store original tensors
a0, b0, M0 = a, b, M

# convert to numpy
M, a, b = nx.to_numpy(M, a, b)
Expand All @@ -491,11 +506,6 @@ def emd2(a, b, M, processes=1,
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64, order='C')

# if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]

assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
Expand Down
6 changes: 5 additions & 1 deletion ot/lp/solver_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
transportation matrix)
"""
a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
x_a, x_b = list_to_array(x_a, x_b)
nx = get_backend(x_a, x_b)
if a is not None:
a = list_to_array(a, nx=nx)
if b is not None:
b = list_to_array(b, nx=nx)

assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
"emd_1d should only be used with monodimensional data"
Expand Down
4 changes: 1 addition & 3 deletions ot/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import warnings
from .lp import emd
from .bregman import sinkhorn
from .utils import list_to_array
from .backend import get_backend

with warnings.catch_warnings():
Expand Down Expand Up @@ -73,7 +72,6 @@ def line_search_armijo(

"""
if nx is None:
xk, pk, gfk = list_to_array(xk, pk, gfk)
xk0, pk0 = xk, pk
nx = get_backend(xk0, pk0)
else:
Expand Down Expand Up @@ -236,7 +234,7 @@ def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_sea
ot.lp.emd : Unregularized optimal transport
ot.bregman.sinkhorn : Entropic regularized optimal transport
"""
a, b, M, G0 = list_to_array(a, b, M, G0)

if isinstance(M, int) or isinstance(M, float):
nx = get_backend(a, b)
else:
Expand Down
24 changes: 21 additions & 3 deletions ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,30 @@ def laplacian(x):
return L


def list_to_array(*lst):
def list_to_array(*lst, nx=None):
r""" Convert a list if in numpy format """
lst_not_empty = [a for a in lst if len(a) > 0 and not isinstance(a, list)]
if nx is None: # find backend

if len(lst_not_empty) == 0:
type_as = np.zeros(0)
nx = get_backend(type_as)
else:
nx = get_backend(*lst_not_empty)
type_as = lst_not_empty[0]
else:
if len(lst_not_empty) == 0:
type_as = None
else:
type_as = lst_not_empty[0]
if len(lst) > 1:
return [np.array(a) if isinstance(a, list) else a for a in lst]
return [nx.from_numpy(np.array(a), type_as=type_as)
if isinstance(a, list) else a for a in lst]
else:
return np.array(lst[0]) if isinstance(lst[0], list) else lst[0]
if isinstance(lst[0], list):
return nx.from_numpy(np.array(lst[0]), type_as=type_as)
else:
return lst[0]


def proj_simplex(v, z=1):
Expand Down
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

if jax:
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
from jax.config import config
from jax import config
config.update("jax_enable_x64", True)

if tf:
Expand Down
4 changes: 4 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ def test_emd2_backends(nx):

valb = ot.emd2(ab, ab, Mb)

# check with empty inputs
valb2 = ot.emd2([], [], Mb)

np.allclose(val, nx.to_numpy(valb))
np.allclose(val, nx.to_numpy(valb2))


def test_emd_emd2_types_devices(nx):
Expand Down
12 changes: 12 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,18 @@ def test_cost_normalization(nx):
ot.utils.cost_normalization(C1, 'error')


def test_list_to_array(nx):
    """Check `ot.utils.list_to_array` on array inputs and on plain lists.

    Two cases are exercised:

    * inputs that are already arrays must come back with their shape intact;
    * inputs that are all plain Python lists (including empty ones) must be
      converted to arrays instead of raising.

    The ``nx`` backend fixture is accepted for consistency with the other
    backend-parametrized tests but is not used by the body.
    """
    lst = [np.array([1, 2, 3]), np.array([4, 5, 6])]

    a1, a2 = ot.utils.list_to_array(*lst)

    assert a1.shape == (3,)
    assert a2.shape == (3,)

    # Every input is a list here, so no argument can provide a backend;
    # the conversion must still succeed for the empty weight lists.
    a, b, M = ot.utils.list_to_array([], [], [[1.0, 2.0], [3.0, 4.0]])

    # Previously the results were unpacked but never checked.
    assert a.shape == (0,)
    assert b.shape == (0,)
    assert M.shape == (2, 2)


def test_check_params():

res1 = ot.utils.check_params(first='OK', second=20)
Expand Down