Skip to content

Commit 3eda17e

Browse files
committed
hopefully final fix
1 parent 5b8e170 commit 3eda17e

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

ot/lp/solver_1d.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
223223
ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
224224
transportation matrix)
225225
"""
226-
a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
226+
x_a, x_b = list_to_array(x_a, x_b)
227227
nx = get_backend(x_a, x_b)
228+
if a is not None:
229+
a = list_to_array(a, nx=nx)
230+
if b is not None:
231+
b = list_to_array(b, nx=nx)
228232

229233
assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
230234
"emd_1d should only be used with monodimensional data"

ot/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,20 @@ def laplacian(x):
5858

5959
def list_to_array(*lst, nx=None):
6060
r""" Convert a list if in numpy format """
61+
lst_not_empty = [a for a in lst if len(a) > 0 and not isinstance(a, list)]
6162
if nx is None: # find backend
62-
lst_not_empty = [a for a in lst if len(a) > 0 and not isinstance(a, list)]
63+
6364
if len(lst_not_empty) == 0:
6465
type_as = np.zeros(0)
6566
nx = get_backend(type_as)
6667
else:
6768
nx = get_backend(*lst_not_empty)
6869
type_as = lst_not_empty[0]
70+
else:
71+
if len(lst_not_empty) == 0:
72+
type_as = None
73+
else:
74+
type_as = lst_not_empty[0]
6975
if len(lst) > 1:
7076
return [nx.from_numpy(np.array(a), type_as=type_as)
7177
if isinstance(a, list) else a for a in lst]

0 commit comments

Comments
 (0)