
Commit 4cf4492

kachayev and rflamary authored
Vectorize cost and gradient for ot.da.sinkhorn_l1l2_gl (#507)
* Vectorize per-label computations for sinkhorn l1l2
* Fix flake
* Convert to double before computing torch norm
* Vectorize gradient for sinkhorn l1l2
* Skip TF and JAX backends (as is done for other DA tests)
* Additional tests for norm() and unique()
* TF implementation of unique should inverse sort permutation for indices
* Fix sort indices for unique(), switched to tf.gather
* tf.math.invert_permutation
* Mention changes in changelog

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 526b72f · commit 4cf4492

5 files changed (+106 −39 lines changed): RELEASES.md, ot/backend.py, ot/da.py, test/test_backend.py, test/test_da.py


RELEASES.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -4,6 +4,7 @@
 
 #### New features
 + Tweaked `get_backend` to ignore `None` inputs (PR #525)
++ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
 
 #### Closed issues
 - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
```

ot/backend.py

Lines changed: 29 additions & 22 deletions
```diff
@@ -411,7 +411,7 @@ def power(self, a, exponents):
         """
         raise NotImplementedError()
 
-    def norm(self, a):
+    def norm(self, a, axis=None, keepdims=False):
         r"""
         Computes the matrix frobenius norm.
 
@@ -631,7 +631,7 @@ def diag(self, a, k=0):
         """
         raise NotImplementedError()
 
-    def unique(self, a):
+    def unique(self, a, return_inverse=False):
         r"""
         Finds unique elements of given tensor.
 
@@ -1091,8 +1091,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return np.power(a, exponents)
 
-    def norm(self, a):
-        return np.sqrt(np.sum(np.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return np.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return np.any(a)
@@ -1168,8 +1168,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return np.diag(a, k)
 
-    def unique(self, a):
-        return np.unique(a)
+    def unique(self, a, return_inverse=False):
+        return np.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return special.logsumexp(a, axis=axis)
```
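For reference, here is a minimal NumPy sketch (not part of the commit) of what the extended backend API computes; the `nx.norm` and `nx.unique` calls used later in `ot/da.py` follow these NumPy semantics:

```python
import numpy as np

M = np.array([[3.0, 4.0], [0.0, 5.0]])

# Frobenius norm of the whole matrix, as before the change
print(np.linalg.norm(M))                               # 7.0710...

# new axis/keepdims options: one Euclidean norm per row
print(np.linalg.norm(M, axis=1))                       # [5. 5.]
print(np.linalg.norm(M, axis=1, keepdims=True).shape)  # (2, 1)

# unique with return_inverse: sorted values plus indices rebuilding the input
labels = np.array([2, 0, 2, 1])
u, inv = np.unique(labels, return_inverse=True)
print(u, inv)                                          # [0 1 2] [2 0 2 1]
assert np.all(u[inv] == labels)
```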
```diff
@@ -1465,8 +1465,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return jnp.power(a, exponents)
 
-    def norm(self, a):
-        return jnp.sqrt(jnp.sum(jnp.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return jnp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return jnp.any(a)
@@ -1539,8 +1539,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return jnp.diag(a, k)
 
-    def unique(self, a):
-        return jnp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return jnp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return jspecial.logsumexp(a, axis=axis)
@@ -1885,8 +1885,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return torch.pow(a, exponents)
 
-    def norm(self, a):
-        return torch.sqrt(torch.sum(torch.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return torch.linalg.norm(a.double(), dim=axis, keepdims=keepdims)
 
     def any(self, a):
         return torch.any(a)
```
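A note on the `.double()` cast in the torch backend: `torch.linalg.norm` requires a floating-point (or complex) input, so integer tensors must be converted first (this is the "Convert to double before computing torch norm" item from the commit message). A small illustration, assuming PyTorch is available:

```python
import torch

a = torch.tensor([[3, 4], [0, 0]])  # integer dtype would make linalg.norm fail
# the cast to double makes the input floating point before taking the norm
print(torch.linalg.norm(a.double(), dim=1))  # tensor([5., 0.], dtype=torch.float64)
```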
```diff
@@ -1990,8 +1990,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return torch.diag(a, diagonal=k)
 
-    def unique(self, a):
-        return torch.unique(a)
+    def unique(self, a, return_inverse=False):
+        return torch.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         if axis is not None:
@@ -2310,8 +2310,8 @@ def power(self, a, exponents):
     def dot(self, a, b):
         return cp.dot(a, b)
 
-    def norm(self, a):
-        return cp.sqrt(cp.sum(cp.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return cp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return cp.any(a)
@@ -2387,8 +2387,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return cp.diag(a, k)
 
-    def unique(self, a):
-        return cp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return cp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         # Taken from
@@ -2721,8 +2721,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return tnp.power(a, exponents)
 
-    def norm(self, a):
-        return tf.math.reduce_euclidean_norm(a)
+    def norm(self, a, axis=None, keepdims=False):
+        return tf.math.reduce_euclidean_norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return tnp.any(a)
@@ -2794,8 +2794,15 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return tnp.diag(a, k)
 
-    def unique(self, a):
-        return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
+    def unique(self, a, return_inverse=False):
+        y, idx = tf.unique(tf.reshape(a, [-1]))
+        sort_idx = tf.argsort(y)
+        y_prime = tf.gather(y, sort_idx)
+        if return_inverse:
+            inv_sort_idx = tf.math.invert_permutation(sort_idx)
+            return y_prime, tf.gather(inv_sort_idx, idx)
+        else:
+            return y_prime
 
     def logsumexp(self, a, axis=None):
         return tf.math.reduce_logsumexp(a, axis=axis)
```
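Why the TF implementation inverts the sort permutation: `tf.unique` returns values in order of first appearance plus indices into that unsorted list, while the backend contract (matching `np.unique`) is sorted values. After sorting, the entry that used to sit at position `i` lives at position `inv_sort_idx[i]`, so the first-appearance indices must be remapped through the inverted permutation. A NumPy sketch of the same index arithmetic (hypothetical input; `np.argsort` of a permutation plays the role of `tf.math.invert_permutation`):

```python
import numpy as np

a = np.array([3, 1, 3, 2])

# tf.unique-style output: values in first-appearance order plus indices
y = np.array([3, 1, 2])        # unique values, unsorted
idx = np.array([0, 1, 0, 2])   # a[i] == y[idx[i]]

sort_idx = np.argsort(y)             # [1 2 0], permutation that sorts y
y_sorted = y[sort_idx]               # [1 2 3]
inv_sort_idx = np.argsort(sort_idx)  # [2 0 1], inverse of that permutation

inverse = inv_sort_idx[idx]          # [2 0 2 1], indices into sorted values
assert np.all(y_sorted[inverse] == a)
```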

ot/da.py

Lines changed: 13 additions & 17 deletions
```diff
@@ -148,8 +148,8 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
 
 
 def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
-                     numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
-                     log=False):
+                     numInnerItermax=200, stopInnerThr=1e-9, eps=1e-12,
+                     verbose=False, log=False):
     r"""
     Solve the entropic regularization optimal transport problem with group
     lasso regularization
@@ -202,6 +202,8 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
         Max number of iterations (inner sinkhorn solver)
     stopInnerThr : float, optional
         Stop threshold on error (inner sinkhorn solver) (>0)
+    eps: float, optional (default=1e-12)
+        Small value to avoid division by zero
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
@@ -235,25 +237,19 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
     a, labels_a, b, M = list_to_array(a, labels_a, b, M)
     nx = get_backend(a, labels_a, b, M)
 
-    lstlab = nx.unique(labels_a)
+    labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
 
     def f(G):
-        res = 0
-        for i in range(G.shape[1]):
-            for lab in lstlab:
-                temp = G[labels_a == lab, i]
-                res += nx.norm(temp)
-        return res
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
+        return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))
 
     def df(G):
-        W = nx.zeros(G.shape, type_as=G)
-        for i in range(G.shape[1]):
-            for lab in lstlab:
-                temp = G[labels_a == lab, i]
-                n = nx.norm(temp)
-                if n:
-                    W[labels_a == lab, i] = temp / n
-        return W
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
+        W = nx.norm(G_split * unroll_labels_idx, axis=1, keepdims=True)
+        G_norm = G_split / nx.clip(W, eps, None)
+        return nx.sum(G_norm, axis=2).T
 
     return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
                numInnerItermax=numInnerItermax, stopThr=stopInnerThr,
```
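The rewritten `f`/`df` compute, for every pair (target column, source label), the Euclidean norm of the rows of `G` carrying that label: `unroll_labels_idx` is a one-hot mask of shape `(1, n_samples, n_labels)`, `G_split` stacks the transposed plan once per label, and `nx.norm(..., axis=1)` takes all group norms in one call; `nx.clip(W, eps, None)` replaces the old `if n:` guard against zero-norm groups. A self-contained NumPy sketch (not the POT backend) checking that the vectorized cost matches the double loop:

```python
import numpy as np

rng = np.random.RandomState(0)
n, m, n_labels = 6, 4, 3
G = rng.rand(n, m)                        # transport plan
labels_a = rng.randint(n_labels, size=n)  # source labels

# loop version: one norm per (target column, source label) group
loops = sum(np.linalg.norm(G[labels_a == lab, j])
            for j in range(m) for lab in range(n_labels))

# vectorized version: zero out each label's complement with a one-hot mask
one_hot = np.eye(n_labels)[labels_a]                    # (n, n_labels)
G_split = np.repeat(G.T[:, :, None], n_labels, axis=2)  # (m, n, n_labels)
vectorized = np.sum(np.linalg.norm(G_split * one_hot[None], axis=1))

assert np.isclose(loops, vectorized)
```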

test/test_backend.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -412,6 +412,14 @@ def test_func_backends(nx):
         lst_b.append(nx.to_numpy(A))
         lst_name.append('norm')
 
+        A = nx.norm(Mb, axis=1)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('norm(M,axis=1)')
+
+        A = nx.norm(Mb, axis=1, keepdims=True)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('norm(M,axis=1,keepdims=True)')
+
         A = nx.any(vb > 0)
         lst_b.append(nx.to_numpy(A))
         lst_name.append('any')
@@ -518,6 +526,12 @@ def test_func_backends(nx):
         lst_b.append(nx.to_numpy(A))
         lst_name.append('unique')
 
+        A, A2 = nx.unique(nx.from_numpy(np.stack([M, M]).reshape(-1)), return_inverse=True)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('unique(M,return_inverse=True)[0]')
+        lst_b.append(nx.to_numpy(A2))
+        lst_name.append('unique(M,return_inverse=True)[1]')
+
         A = nx.logsumexp(Mb)
         lst_b.append(nx.to_numpy(A))
         lst_name.append('logsumexp')
```

test/test_da.py

Lines changed: 49 additions & 0 deletions
```diff
@@ -801,3 +801,52 @@ def test_emd_laplace_class(nx):
     transp_ys = otda.inverse_transform_labels(yt)
     assert_equal(transp_ys.shape[0], ys.shape[0])
     assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_l1l2_gl_cost_vectorized(nx):
+    n_samples, n_labels = 150, 3
+    rng = np.random.RandomState(42)
+    G = rng.rand(n_samples, n_samples)
+    labels_a = rng.randint(n_labels, size=(n_samples,))
+    G, labels_a = nx.from_numpy(G), nx.from_numpy(labels_a)
+
+    # previously used implementation for the cost estimator
+    lstlab = nx.unique(labels_a)
+
+    def f(G):
+        res = 0
+        for i in range(G.shape[1]):
+            for lab in lstlab:
+                temp = G[labels_a == lab, i]
+                res += nx.norm(temp)
+        return res
+
+    def df(G):
+        W = nx.zeros(G.shape, type_as=G)
+        for i in range(G.shape[1]):
+            for lab in lstlab:
+                temp = G[labels_a == lab, i]
+                n = nx.norm(temp)
+                if n:
+                    W[labels_a == lab, i] = temp / n
+        return W
+
+    # new vectorized implementation for the cost estimator
+    labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
+
+    def f2(G):
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
+        return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))
+
+    def df2(G):
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
+        W = nx.norm(G_split * unroll_labels_idx, axis=1, keepdims=True)
+        G_norm = G_split / nx.clip(W, 1e-12, None)
+        return nx.sum(G_norm, axis=2).T
+
+    assert np.allclose(f(G), f2(G))
+    assert np.allclose(df(G), df2(G))
```
