Skip to content

Commit 0573eba

Browse files
authored
[MRG] Fix bug in emd2 with empty weighs on backends (#606)
* fix buf emd2 for empty inputs * update release file * debug problems in optimization hen using list_to_arry by removing it everywhere * update jax config in tests * hopefully final fix
1 parent f1fe593 commit 0573eba

File tree

11 files changed

+77
-34
lines changed

11 files changed

+77
-34
lines changed

RELEASES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Releases
22

3-
## 0.9.3
3+
## 0.9.3dev
44

55
#### New features
66
+ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster.
@@ -9,6 +9,7 @@
99
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)
1010
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)
1111
- Fix doc and example for lowrank sinkhorn (PR #601)
12+
- Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534)
1213

1314
## 0.9.2
1415
*December 2023*

ot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
# utils functions
5959
from .utils import dist, unif, tic, toc, toq
6060

61-
__version__ = "0.9.3"
61+
__version__ = "0.9.3dev"
6262

6363
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
6464
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',

ot/gromov/_gw.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,8 +703,6 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
703703
704704
"""
705705
if nx is None:
706-
G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)
707-
708706
if isinstance(M, int) or isinstance(M, float):
709707
nx = get_backend(G, deltaG, C1, C2)
710708
else:

ot/gromov/_semirelaxed.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,6 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
583583
Gromov-Wasserstein". NeurIPS 2023 Workshop OTML.
584584
"""
585585
if nx is None:
586-
G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)
587-
588586
if isinstance(M, int) or isinstance(M, float):
589587
nx = get_backend(G, deltaG, C1, C2)
590588
else:

ot/lp/__init__.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -302,17 +302,24 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c
302302
ot.optim.cg : General regularized OT
303303
"""
304304

305-
# convert to numpy if list
306305
a, b, M = list_to_array(a, b, M)
306+
nx = get_backend(M, a, b)
307307

308-
a0, b0, M0 = a, b, M
309-
if len(a0) != 0:
310-
type_as = a0
311-
elif len(b0) != 0:
312-
type_as = b0
308+
if len(a) != 0:
309+
type_as = a
310+
elif len(b) != 0:
311+
type_as = b
313312
else:
314-
type_as = M0
315-
nx = get_backend(M0, a0, b0)
313+
type_as = M
314+
315+
# if empty array given then use uniform distributions
316+
if len(a) == 0:
317+
a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
318+
if len(b) == 0:
319+
b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]
320+
321+
# store original tensors
322+
a0, b0, M0 = a, b, M
316323

317324
# convert to numpy
318325
M, a, b = nx.to_numpy(M, a, b)
@@ -474,15 +481,23 @@ def emd2(a, b, M, processes=1,
474481
"""
475482

476483
a, b, M = list_to_array(a, b, M)
484+
nx = get_backend(M, a, b)
477485

478-
a0, b0, M0 = a, b, M
479-
if len(a0) != 0:
480-
type_as = a0
481-
elif len(b0) != 0:
482-
type_as = b0
486+
if len(a) != 0:
487+
type_as = a
488+
elif len(b) != 0:
489+
type_as = b
483490
else:
484-
type_as = M0
485-
nx = get_backend(M0, a0, b0)
491+
type_as = M
492+
493+
# if empty array given then use uniform distributions
494+
if len(a) == 0:
495+
a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
496+
if len(b) == 0:
497+
b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]
498+
499+
# store original tensors
500+
a0, b0, M0 = a, b, M
486501

487502
# convert to numpy
488503
M, a, b = nx.to_numpy(M, a, b)
@@ -491,11 +506,6 @@ def emd2(a, b, M, processes=1,
491506
b = np.asarray(b, dtype=np.float64)
492507
M = np.asarray(M, dtype=np.float64, order='C')
493508

494-
# if empty array given then use uniform distributions
495-
if len(a) == 0:
496-
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
497-
if len(b) == 0:
498-
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
499509

500510
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
501511
"Dimension mismatch, check dimensions of M with a and b"

ot/lp/solver_1d.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
223223
ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
224224
transportation matrix)
225225
"""
226-
a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
226+
x_a, x_b = list_to_array(x_a, x_b)
227227
nx = get_backend(x_a, x_b)
228+
if a is not None:
229+
a = list_to_array(a, nx=nx)
230+
if b is not None:
231+
b = list_to_array(b, nx=nx)
228232

229233
assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
230234
"emd_1d should only be used with monodimensional data"

ot/optim.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import warnings
1313
from .lp import emd
1414
from .bregman import sinkhorn
15-
from .utils import list_to_array
1615
from .backend import get_backend
1716

1817
with warnings.catch_warnings():
@@ -73,7 +72,6 @@ def line_search_armijo(
7372
7473
"""
7574
if nx is None:
76-
xk, pk, gfk = list_to_array(xk, pk, gfk)
7775
xk0, pk0 = xk, pk
7876
nx = get_backend(xk0, pk0)
7977
else:
@@ -236,7 +234,7 @@ def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_sea
236234
ot.lp.emd : Unregularized optimal transport
237235
ot.bregman.sinkhorn : Entropic regularized optimal transport
238236
"""
239-
a, b, M, G0 = list_to_array(a, b, M, G0)
237+
240238
if isinstance(M, int) or isinstance(M, float):
241239
nx = get_backend(a, b)
242240
else:

ot/utils.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,30 @@ def laplacian(x):
5656
return L
5757

5858

59-
def list_to_array(*lst):
59+
def list_to_array(*lst, nx=None):
6060
r""" Convert a list if in numpy format """
61+
lst_not_empty = [a for a in lst if len(a) > 0 and not isinstance(a, list)]
62+
if nx is None: # find backend
63+
64+
if len(lst_not_empty) == 0:
65+
type_as = np.zeros(0)
66+
nx = get_backend(type_as)
67+
else:
68+
nx = get_backend(*lst_not_empty)
69+
type_as = lst_not_empty[0]
70+
else:
71+
if len(lst_not_empty) == 0:
72+
type_as = None
73+
else:
74+
type_as = lst_not_empty[0]
6175
if len(lst) > 1:
62-
return [np.array(a) if isinstance(a, list) else a for a in lst]
76+
return [nx.from_numpy(np.array(a), type_as=type_as)
77+
if isinstance(a, list) else a for a in lst]
6378
else:
64-
return np.array(lst[0]) if isinstance(lst[0], list) else lst[0]
79+
if isinstance(lst[0], list):
80+
return nx.from_numpy(np.array(lst[0]), type_as=type_as)
81+
else:
82+
return lst[0]
6583

6684

6785
def proj_simplex(v, z=1):

test/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
if jax:
1515
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
16-
from jax.config import config
16+
from jax import config
1717
config.update("jax_enable_x64", True)
1818

1919
if tf:

test/test_ot.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,11 @@ def test_emd2_backends(nx):
7474

7575
valb = ot.emd2(ab, ab, Mb)
7676

77+
# check with empty inputs
78+
valb2 = ot.emd2([], [], Mb)
79+
7780
np.allclose(val, nx.to_numpy(valb))
81+
np.allclose(val, nx.to_numpy(valb2))
7882

7983

8084
def test_emd_emd2_types_devices(nx):

test/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,18 @@ def test_cost_normalization(nx):
322322
ot.utils.cost_normalization(C1, 'error')
323323

324324

325+
def test_list_to_array(nx):
326+
327+
lst = [np.array([1, 2, 3]), np.array([4, 5, 6])]
328+
329+
a1, a2 = ot.utils.list_to_array(*lst)
330+
331+
assert a1.shape == (3,)
332+
assert a2.shape == (3,)
333+
334+
a, b, M = ot.utils.list_to_array([], [], [[1.0, 2.0], [3.0, 4.0]])
335+
336+
325337
def test_check_params():
326338

327339
res1 = ot.utils.check_params(first='OK', second=20)

0 commit comments

Comments
 (0)