Skip to content

Commit f4ca25a

Browse files
committed
Vectorize per-label computations for sinkhorn l1l2
1 parent 9dd4d8d commit f4ca25a

File tree

3 files changed

+57
-30
lines changed

3 files changed

+57
-30
lines changed

ot/backend.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def power(self, a, exponents):
407407
"""
408408
raise NotImplementedError()
409409

410-
def norm(self, a):
410+
def norm(self, a, axis=None):
411411
r"""
412412
Computes the matrix frobenius norm.
413413
@@ -627,7 +627,7 @@ def diag(self, a, k=0):
627627
"""
628628
raise NotImplementedError()
629629

630-
def unique(self, a):
630+
def unique(self, a, return_inverse=False):
631631
r"""
632632
Finds unique elements of given tensor.
633633
@@ -1087,8 +1087,8 @@ def sqrt(self, a):
10871087
def power(self, a, exponents):
10881088
return np.power(a, exponents)
10891089

1090-
def norm(self, a):
1091-
return np.sqrt(np.sum(np.square(a)))
1090+
def norm(self, a, axis=None):
1091+
return np.linalg.norm(a, axis=axis)
10921092

10931093
def any(self, a):
10941094
return np.any(a)
@@ -1164,8 +1164,8 @@ def meshgrid(self, a, b):
11641164
def diag(self, a, k=0):
11651165
return np.diag(a, k)
11661166

1167-
def unique(self, a):
1168-
return np.unique(a)
1167+
def unique(self, a, return_inverse=False):
1168+
return np.unique(a, return_inverse=return_inverse)
11691169

11701170
def logsumexp(self, a, axis=None):
11711171
return special.logsumexp(a, axis=axis)
@@ -1461,8 +1461,8 @@ def sqrt(self, a):
14611461
def power(self, a, exponents):
14621462
return jnp.power(a, exponents)
14631463

1464-
def norm(self, a):
1465-
return jnp.sqrt(jnp.sum(jnp.square(a)))
1464+
def norm(self, a, axis=None):
1465+
return jnp.linalg.norm(a, axis=axis)
14661466

14671467
def any(self, a):
14681468
return jnp.any(a)
@@ -1535,8 +1535,8 @@ def meshgrid(self, a, b):
15351535
def diag(self, a, k=0):
15361536
return jnp.diag(a, k)
15371537

1538-
def unique(self, a):
1539-
return jnp.unique(a)
1538+
def unique(self, a, return_inverse=False):
1539+
return jnp.unique(a, return_inverse=return_inverse)
15401540

15411541
def logsumexp(self, a, axis=None):
15421542
return jspecial.logsumexp(a, axis=axis)
@@ -1881,8 +1881,8 @@ def sqrt(self, a):
18811881
def power(self, a, exponents):
18821882
return torch.pow(a, exponents)
18831883

1884-
def norm(self, a):
1885-
return torch.sqrt(torch.sum(torch.square(a)))
1884+
def norm(self, a, axis=None):
1885+
return torch.linalg.norm(a, dim=axis)
18861886

18871887
def any(self, a):
18881888
return torch.any(a)
@@ -1986,8 +1986,8 @@ def meshgrid(self, a, b):
19861986
def diag(self, a, k=0):
19871987
return torch.diag(a, diagonal=k)
19881988

1989-
def unique(self, a):
1990-
return torch.unique(a)
1989+
def unique(self, a, return_inverse=False):
1990+
return torch.unique(a, return_inverse=return_inverse)
19911991

19921992
def logsumexp(self, a, axis=None):
19931993
if axis is not None:
@@ -2306,8 +2306,8 @@ def power(self, a, exponents):
23062306
def dot(self, a, b):
23072307
return cp.dot(a, b)
23082308

2309-
def norm(self, a):
2310-
return cp.sqrt(cp.sum(cp.square(a)))
2309+
def norm(self, a, axis=None):
2310+
return cp.linalg.norm(a, axis=axis)
23112311

23122312
def any(self, a):
23132313
return cp.any(a)
@@ -2383,8 +2383,8 @@ def meshgrid(self, a, b):
23832383
def diag(self, a, k=0):
23842384
return cp.diag(a, k)
23852385

2386-
def unique(self, a):
2387-
return cp.unique(a)
2386+
def unique(self, a, return_inverse=False):
2387+
return cp.unique(a, return_inverse=return_inverse)
23882388

23892389
def logsumexp(self, a, axis=None):
23902390
# Taken from
@@ -2717,8 +2717,8 @@ def sqrt(self, a):
27172717
def power(self, a, exponents):
27182718
return tnp.power(a, exponents)
27192719

2720-
def norm(self, a):
2721-
return tf.math.reduce_euclidean_norm(a)
2720+
def norm(self, a, axis=None):
2721+
return tf.math.reduce_euclidean_norm(a, axis=axis)
27222722

27232723
def any(self, a):
27242724
return tnp.any(a)
@@ -2790,8 +2790,10 @@ def meshgrid(self, a, b):
27902790
def diag(self, a, k=0):
27912791
return tnp.diag(a, k)
27922792

2793-
def unique(self, a):
2794-
return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
2793+
def unique(self, a, return_inverse=False):
    """Find the sorted unique elements of *a* (flattened).

    When ``return_inverse`` is True, also return, for every element of the
    flattened input, its index into the sorted unique values — mirroring
    ``np.unique(a, return_inverse=True)``.
    """
    # tf.unique returns the uniques in order of first appearance plus,
    # for each input element, its index into that (unsorted) list.
    y, idx = tf.unique(tf.reshape(a, [-1]))
    sort_idx = tf.argsort(y)
    y_sorted = tf.gather(y, sort_idx)
    if return_inverse:
        # idx points into the *unsorted* y; remap it through the inverse
        # of the sorting permutation so it points into y_sorted instead.
        # (idx[sort_idx] would be wrong: it has length n_unique, not the
        # length of a, and mixes up the two index spaces.)
        inv_sort_idx = tf.math.invert_permutation(sort_idx)
        return y_sorted, tf.gather(inv_sort_idx, idx)
    return y_sorted
27952797

27962798
def logsumexp(self, a, axis=None):
27972799
return tf.math.reduce_logsumexp(a, axis=axis)

ot/da.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -235,16 +235,14 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
235235
a, labels_a, b, M = list_to_array(a, labels_a, b, M)
236236
nx = get_backend(a, labels_a, b, M)
237237

238-
lstlab = nx.unique(labels_a)
239-
238+
labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
239+
n_labels = labels_u.shape[0]
240+
unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
240241
def f(G):
241-
res = 0
242-
for i in range(G.shape[1]):
243-
for lab in lstlab:
244-
temp = G[labels_a == lab, i]
245-
res += nx.norm(temp)
246-
return res
242+
G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
243+
return nx.norm(G_split * unroll_labels_idx, axis=1).sum()
247244

245+
lstlab = nx.unique(labels_a)
248246
def df(G):
249247
W = nx.zeros(G.shape, type_as=G)
250248
for i in range(G.shape[1]):

test/test_da.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,3 +800,30 @@ def test_emd_laplace_class(nx):
800800
transp_ys = otda.inverse_transform_labels(yt)
801801
assert_equal(transp_ys.shape[0], ys.shape[0])
802802
assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))
803+
804+
805+
def test_sinkhorn_l1l2_gl_cost_vectorized():
    """Check that the vectorized group-lasso cost matches the loop version."""
    n_samples, n_labels = 150, 3
    rng = np.random.RandomState(42)
    G = rng.rand(n_samples, n_samples)
    labels_a = rng.randint(n_labels, size=(n_samples,))

    # previously used implementation for the cost estimator
    lstlab = np.unique(labels_a)

    def f(G):
        res = 0
        for i in range(G.shape[1]):
            for lab in lstlab:
                temp = G[labels_a == lab, i]
                res += np.linalg.norm(temp)
        return res

    # new vectorized implementation for the cost estimator
    lstlab, lstlab_idx = np.unique(labels_a, return_inverse=True)
    # NOTE: use a distinct name for the number of labels found; the original
    # draft rebound ``n_samples`` here, silently shadowing the sample count.
    n_labels_found = lstlab.shape[0]
    # one-hot mask of shape (1, n_samples, n_labels_found) selecting, per
    # column entry, the label group it belongs to
    midx = np.eye(n_labels_found, dtype='int32')[None, lstlab_idx]

    def f2(G):
        G_split = np.repeat(G.T[:, :, None], n_labels_found, axis=2)
        return np.linalg.norm(G_split * midx, axis=1).sum()

    assert np.allclose(f(G), f2(G))

0 commit comments

Comments
 (0)