From 058fd232a4014a3d45ffca3d87ffb16b4f851b82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 29 Feb 2024 12:19:16 +0100 Subject: [PATCH 1/5] fix buf emd2 for empty inputs --- ot/lp/__init__.py | 50 +++++++++++++++++++++++++++------------------- ot/utils.py | 18 ++++++++++++++--- test/test_ot.py | 3 +++ test/test_utils.py | 12 +++++++++++ 4 files changed, 60 insertions(+), 23 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 545d1d8cd..93316a6c1 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -302,17 +302,24 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c ot.optim.cg : General regularized OT """ - # convert to numpy if list a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) - a0, b0, M0 = a, b, M - if len(a0) != 0: - type_as = a0 - elif len(b0) != 0: - type_as = b0 + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b else: - type_as = M0 - nx = get_backend(M0, a0, b0) + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # store original tensors + a0, b0, M0 = a, b, M # convert to numpy M, a, b = nx.to_numpy(M, a, b) @@ -474,15 +481,23 @@ def emd2(a, b, M, processes=1, """ a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) - a0, b0, M0 = a, b, M - if len(a0) != 0: - type_as = a0 - elif len(b0) != 0: - type_as = b0 + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b else: - type_as = M0 - nx = get_backend(M0, a0, b0) + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # store original tensors + a0, b0, M0 = a, b, M # convert to numpy M, a, b = nx.to_numpy(M, a, b) @@ -491,11 +506,6 @@ def emd2(a, b, M, processes=1, b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64, order='C') - # if empty array given then use uniform distributions - if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" diff --git a/ot/utils.py b/ot/utils.py index 19e61f1fe..e0b6b51a7 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -56,12 +56,24 @@ def laplacian(x): return L -def list_to_array(*lst): +def list_to_array(*lst, nx=None): r""" Convert a list if in numpy format """ + if nx is None: # find backend + lst_not_empty = [a for a in lst if len(a) > 0 and not isinstance(a, list)] + if len(lst_not_empty) == 0: + type_as = np.zeros(0) + nx = get_backend(type_as) + else: + nx = get_backend(*lst_not_empty) + type_as = lst_not_empty[0] if len(lst) > 1: - return [np.array(a) if isinstance(a, list) else a for a in lst] + return [nx.from_numpy(np.array(a), type_as=type_as) + if isinstance(a, list) else a for a in lst] else: - return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + if isinstance(lst[0], list): + return nx.from_numpy(np.array(lst[0]), type_as=type_as) + else: + return lst[0] def proj_simplex(v, z=1): diff --git a/test/test_ot.py b/test/test_ot.py index 5c6e6732b..91513470f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -74,6 +74,9 @@ def test_emd2_backends(nx): valb = ot.emd2(ab, ab, Mb) + # check with empty inputs + valb2 = ot.emd2([], [], Mb) + np.allclose(val, nx.to_numpy(valb)) diff --git a/test/test_utils.py b/test/test_utils.py index 6cdb7ead7..966cef989 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -322,6 +322,18 @@ def test_cost_normalization(nx): ot.utils.cost_normalization(C1, 'error') +def test_list_to_array(nx): + + lst = [np.array([1, 2, 3]), np.array([4, 5, 6])] + + a1, a2 = ot.utils.list_to_array(*lst) + + assert a1.shape == (3,) + assert a2.shape == (3,) + + a, b, M = ot.utils.list_to_array([], [], [[1.0, 2.0], [3.0, 4.0]]) + + def test_check_params(): res1 = ot.utils.check_params(first='OK', second=20) From 09f4a74416f455c9f556321a9b881a28d52c5007 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 29 Feb 2024 12:33:01 +0100 Subject: [PATCH 2/5] update release file --- RELEASES.md | 3 ++- ot/__init__.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 998d56836..cc83f2530 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,11 +1,12 @@ # Releases -## 0.9.3 +## 0.9.3dev #### Closed issues - Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593) - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) - Fix doc and example for lowrank sinkhorn (PR #601) +- Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534) ## 0.9.2 *December 2023* diff --git a/ot/__init__.py b/ot/__init__.py index db49d6c34..9a63b5f6f 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -58,7 +58,7 @@ # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.9.3" +__version__ = "0.9.3dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', From 0c3e412a91b898f0e151fb109831c738a8633741 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 29 Feb 2024 12:55:48 +0100 Subject: [PATCH 3/5] debug problems in optimization hen using list_to_arry by removing it everywhere --- ot/gromov/_gw.py | 2 -- ot/gromov/_semirelaxed.py | 2 -- ot/optim.py | 4 +--- test/test_ot.py | 1 + 4 files changed, 2 insertions(+), 7 deletions(-) diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 69dd3df0c..ea0be1833 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -699,8 +699,6 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, """ if nx is None: - G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) - if isinstance(M, int) or isinstance(M, float): nx = get_backend(G, deltaG, C1, C2) else: diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index fb9d2b3ca..c37ba2bf4 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -583,8 +583,6 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, Gromov-Wasserstein". NeurIPS 2023 Workshop OTML. """ if nx is None: - G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) - if isinstance(M, int) or isinstance(M, float): nx = get_backend(G, deltaG, C1, C2) else: diff --git a/ot/optim.py b/ot/optim.py index 8700f75d1..dcdef6a88 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -12,7 +12,6 @@ import warnings from .lp import emd from .bregman import sinkhorn -from .utils import list_to_array from .backend import get_backend with warnings.catch_warnings(): @@ -73,7 +72,6 @@ def line_search_armijo( """ if nx is None: - xk, pk, gfk = list_to_array(xk, pk, gfk) xk0, pk0 = xk, pk nx = get_backend(xk0, pk0) else: @@ -236,7 +234,7 @@ def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_sea ot.lp.emd : Unregularized optimal transport ot.bregman.sinkhorn : Entropic regularized optimal transport """ - a, b, M, G0 = list_to_array(a, b, M, G0) + if isinstance(M, int) or isinstance(M, float): nx = get_backend(a, b) else: diff --git a/test/test_ot.py b/test/test_ot.py index 91513470f..a90321d5f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -78,6 +78,7 @@ def test_emd2_backends(nx): valb2 = ot.emd2([], [], Mb) np.allclose(val, nx.to_numpy(valb)) + np.allclose(val, nx.to_numpy(valb2)) def test_emd_emd2_types_devices(nx): From 46ab667b553a0c3f9252b12462344524195be544 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 29 Feb 2024 13:40:15 +0100 Subject: [PATCH 4/5] update jax config in tests --- test/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/conftest.py b/test/conftest.py index 0303ed9f2..043c8ca70 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -13,7 +13,7 @@ if jax: os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' - from jax.config import config + from jax import config config.update("jax_enable_x64", True) if tf: From 3eda17e616d04fb68558a47c8cb65853263db061 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 1 Mar 2024 08:31:10 +0100 Subject: [PATCH 5/5] hopefully final fix --- ot/lp/solver_1d.py | 6 +++++- ot/utils.py | 8 +++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index e792db904..d9395c8d4 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -223,8 +223,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the transportation matrix) """ - a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) + x_a, x_b = list_to_array(x_a, x_b) nx = get_backend(x_a, x_b) + if a is not None: + a = list_to_array(a, nx=nx) + if b is not None: + b = list_to_array(b, nx=nx) assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ "emd_1d should only be used with monodimensional data" diff --git a/ot/utils.py b/ot/utils.py index e0b6b51a7..404a9f2db 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -58,14 +58,20 @@ def laplacian(x): def list_to_array(*lst, nx=None): r""" Convert a list if in numpy format """ + lst_not_empty = [a for a in lst if len(a) > 0 and not isinstance(a, list)] if nx is None: # find backend - lst_not_empty = [a for a in lst if len(a) > 0 and not isinstance(a, list)] + if len(lst_not_empty) == 0: type_as = np.zeros(0) nx = get_backend(type_as) else: nx = get_backend(*lst_not_empty) type_as = lst_not_empty[0] + else: + if len(lst_not_empty) == 0: + type_as = None + else: + type_as = lst_not_empty[0] if len(lst) > 1: return [nx.from_numpy(np.array(a), type_as=type_as) if isinstance(a, list) else a for a in lst]