Commit a016406

Merge branch 'master' into fix-backend-allocations
2 parents 26cb118 + 4cf4492

5 files changed: +106, -39 lines

RELEASES.md

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 
 #### New features
 + Tweaked `get_backend` to ignore `None` inputs (PR #525)
++ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
 
 #### Closed issues
 - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
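For context on the first entry, a minimal sketch of the new behavior (illustrative values; assumes the public `ot.backend.get_backend` helper, which is not itself touched in this diff):

import numpy as np
from ot.backend import get_backend

a = np.ones((3, 3))
# With this change, None arguments are simply skipped, so the backend
# is inferred from the remaining non-None inputs instead of raising.
nx = get_backend(a, None)
print(nx.__name__)  # expected: 'numpy'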

ot/backend.py

Lines changed: 29 additions & 22 deletions
@@ -460,7 +460,7 @@ def power(self, a, exponents):
         """
         raise NotImplementedError()
 
-    def norm(self, a):
+    def norm(self, a, axis=None, keepdims=False):
         r"""
         Computes the matrix frobenius norm.
 
@@ -680,7 +680,7 @@ def diag(self, a, k=0):
         """
         raise NotImplementedError()
 
-    def unique(self, a):
+    def unique(self, a, return_inverse=False):
         r"""
         Finds unique elements of given tensor.
 
@@ -1140,8 +1140,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return np.power(a, exponents)
 
-    def norm(self, a):
-        return np.sqrt(np.sum(np.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return np.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return np.any(a)
@@ -1217,8 +1217,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return np.diag(a, k)
 
-    def unique(self, a):
-        return np.unique(a)
+    def unique(self, a, return_inverse=False):
+        return np.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return special.logsumexp(a, axis=axis)
@@ -1514,8 +1514,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return jnp.power(a, exponents)
 
-    def norm(self, a):
-        return jnp.sqrt(jnp.sum(jnp.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return jnp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return jnp.any(a)
@@ -1588,8 +1588,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return jnp.diag(a, k)
 
-    def unique(self, a):
-        return jnp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return jnp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return jspecial.logsumexp(a, axis=axis)
@@ -1934,8 +1934,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return torch.pow(a, exponents)
 
-    def norm(self, a):
-        return torch.sqrt(torch.sum(torch.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return torch.linalg.norm(a.double(), dim=axis, keepdims=keepdims)
 
     def any(self, a):
         return torch.any(a)
@@ -2039,8 +2039,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return torch.diag(a, diagonal=k)
 
-    def unique(self, a):
-        return torch.unique(a)
+    def unique(self, a, return_inverse=False):
+        return torch.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         if axis is not None:
@@ -2359,8 +2359,8 @@ def power(self, a, exponents):
     def dot(self, a, b):
         return cp.dot(a, b)
 
-    def norm(self, a):
-        return cp.sqrt(cp.sum(cp.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return cp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return cp.any(a)
@@ -2436,8 +2436,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return cp.diag(a, k)
 
-    def unique(self, a):
-        return cp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return cp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         # Taken from
@@ -2770,8 +2770,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return tnp.power(a, exponents)
 
-    def norm(self, a):
-        return tf.math.reduce_euclidean_norm(a)
+    def norm(self, a, axis=None, keepdims=False):
+        return tf.math.reduce_euclidean_norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return tnp.any(a)
@@ -2843,8 +2843,15 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return tnp.diag(a, k)
 
-    def unique(self, a):
-        return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
+    def unique(self, a, return_inverse=False):
+        y, idx = tf.unique(tf.reshape(a, [-1]))
+        sort_idx = tf.argsort(y)
+        y_prime = tf.gather(y, sort_idx)
+        if return_inverse:
+            inv_sort_idx = tf.math.invert_permutation(sort_idx)
+            return y_prime, tf.gather(inv_sort_idx, idx)
+        else:
+            return y_prime
 
     def logsumexp(self, a, axis=None):
         return tf.math.reduce_logsumexp(a, axis=axis)
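Taken together, these hunks extend the backend interface with NumPy-style reduction arguments. A minimal sketch of the extended API (assuming the NumPy backend; the array values are illustrative only):

import numpy as np
from ot.backend import NumpyBackend

nx = NumpyBackend()
M = np.arange(6, dtype=np.float64).reshape(3, 2)

# norm gains axis/keepdims, matching np.linalg.norm semantics.
print(nx.norm(M))                         # Frobenius norm (scalar)
print(nx.norm(M, axis=1))                 # per-row norms, shape (3,)
print(nx.norm(M, axis=1, keepdims=True))  # shape (3, 1)

# unique gains return_inverse; the inverse index rebuilds the input.
a = np.array([2.0, 1.0, 2.0])
u, inv = nx.unique(a, return_inverse=True)
assert np.all(u[inv] == a)

The TensorFlow override is the only non-trivial one: tf.unique neither sorts its output nor exposes an inverse into the sorted order, so the new code sorts the unique values and routes the raw inverse index through tf.math.invert_permutation to keep it consistent with that sorted order.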

ot/da.py

Lines changed: 13 additions & 17 deletions
@@ -148,8 +148,8 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
 
 
 def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
-                     numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
-                     log=False):
+                     numInnerItermax=200, stopInnerThr=1e-9, eps=1e-12,
+                     verbose=False, log=False):
     r"""
     Solve the entropic regularization optimal transport problem with group
     lasso regularization
@@ -202,6 +202,8 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
         Max number of iterations (inner sinkhorn solver)
     stopInnerThr : float, optional
         Stop threshold on error (inner sinkhorn solver) (>0)
+    eps: float, optional (default=1e-12)
+        Small value to avoid division by zero
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
@@ -235,25 +237,19 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
     a, labels_a, b, M = list_to_array(a, labels_a, b, M)
     nx = get_backend(a, labels_a, b, M)
 
-    lstlab = nx.unique(labels_a)
+    labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
 
     def f(G):
-        res = 0
-        for i in range(G.shape[1]):
-            for lab in lstlab:
-                temp = G[labels_a == lab, i]
-                res += nx.norm(temp)
-        return res
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
+        return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))
 
     def df(G):
-        W = nx.zeros(G.shape, type_as=G)
-        for i in range(G.shape[1]):
-            for lab in lstlab:
-                temp = G[labels_a == lab, i]
-                n = nx.norm(temp)
-                if n:
-                    W[labels_a == lab, i] = temp / n
-        return W
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
+        W = nx.norm(G_split * unroll_labels_idx, axis=1, keepdims=True)
+        G_norm = G_split / nx.clip(W, eps, None)
+        return nx.sum(G_norm, axis=2).T
 
     return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
                numInnerItermax=numInnerItermax, stopThr=stopInnerThr,
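The rewrite trades the per-column, per-label Python loops for one masked reduction: unroll_labels_idx is a one-hot mask of shape (1, n_samples, n_labels), so multiplying it against the repeated transport-plan columns isolates each label group, and a single norm over the sample axis yields every group norm at once. A standalone NumPy sketch of the equivalence (made-up sizes; `onehot` plays the role of `unroll_labels_idx`):

import numpy as np

rng = np.random.RandomState(0)
n, m, n_labels = 5, 4, 3
G = rng.rand(n, m)
labels_a = rng.randint(n_labels, size=n)

# One-hot mask: entry [i, c] is 1 exactly when sample i has label c.
onehot = np.eye(n_labels)[labels_a]                       # (n, n_labels)

# Looped version: per-(column, label) group norms, as in the old f(G).
loop = sum(np.linalg.norm(G[labels_a == c, j])
           for j in range(m) for c in range(n_labels))

# Vectorized version: mask an (m, n, n_labels) stack, reduce once.
G_split = np.repeat(G.T[:, :, None], n_labels, axis=2)
vec = np.sum(np.linalg.norm(G_split * onehot[None], axis=1))

assert np.allclose(loop, vec)

In the gradient, the old `if n:` guard against all-zero groups becomes nx.clip(W, eps, None), which is why the new eps parameter exists.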

test/test_backend.py

Lines changed: 14 additions & 0 deletions
@@ -412,6 +412,14 @@ def test_func_backends(nx):
     lst_b.append(nx.to_numpy(A))
     lst_name.append('norm')
 
+    A = nx.norm(Mb, axis=1)
+    lst_b.append(nx.to_numpy(A))
+    lst_name.append('norm(M,axis=1)')
+
+    A = nx.norm(Mb, axis=1, keepdims=True)
+    lst_b.append(nx.to_numpy(A))
+    lst_name.append('norm(M,axis=1,keepdims=True)')
+
     A = nx.any(vb > 0)
     lst_b.append(nx.to_numpy(A))
     lst_name.append('any')
@@ -518,6 +526,12 @@ def test_func_backends(nx):
     lst_b.append(nx.to_numpy(A))
     lst_name.append('unique')
 
+    A, A2 = nx.unique(nx.from_numpy(np.stack([M, M]).reshape(-1)), return_inverse=True)
+    lst_b.append(nx.to_numpy(A))
+    lst_name.append('unique(M,return_inverse=True)[0]')
+    lst_b.append(nx.to_numpy(A2))
+    lst_name.append('unique(M,return_inverse=True)[1]')
+
     A = nx.logsumexp(Mb)
     lst_b.append(nx.to_numpy(A))
     lst_name.append('logsumexp')
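The duplicated input np.stack([M, M]).reshape(-1) guarantees repeated values, so the inverse index is non-trivial. A short sketch of the property each backend must reproduce (plain NumPy, mirroring what the backend results are checked against):

import numpy as np

M = np.arange(4.0)
flat = np.stack([M, M]).reshape(-1)           # every value appears twice
u, inv = np.unique(flat, return_inverse=True)
assert np.all(u[inv] == flat)                 # inverse index rebuilds the input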

test/test_da.py

Lines changed: 49 additions & 0 deletions
@@ -801,3 +801,52 @@ def test_emd_laplace_class(nx):
     transp_ys = otda.inverse_transform_labels(yt)
     assert_equal(transp_ys.shape[0], ys.shape[0])
     assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_l1l2_gl_cost_vectorized(nx):
+    n_samples, n_labels = 150, 3
+    rng = np.random.RandomState(42)
+    G = rng.rand(n_samples, n_samples)
+    labels_a = rng.randint(n_labels, size=(n_samples,))
+    G, labels_a = nx.from_numpy(G), nx.from_numpy(labels_a)
+
+    # previously used implementation for the cost estimator
+    lstlab = nx.unique(labels_a)
+
+    def f(G):
+        res = 0
+        for i in range(G.shape[1]):
+            for lab in lstlab:
+                temp = G[labels_a == lab, i]
+                res += nx.norm(temp)
+        return res
+
+    def df(G):
+        W = nx.zeros(G.shape, type_as=G)
+        for i in range(G.shape[1]):
+            for lab in lstlab:
+                temp = G[labels_a == lab, i]
+                n = nx.norm(temp)
+                if n:
+                    W[labels_a == lab, i] = temp / n
+        return W
+
+    # new vectorized implementation for the cost estimator
+    labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
+    n_labels = labels_u.shape[0]
+    unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
+
+    def f2(G):
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
+        return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))
+
+    def df2(G):
+        G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
+        W = nx.norm(G_split * unroll_labels_idx, axis=1, keepdims=True)
+        G_norm = G_split / nx.clip(W, 1e-12, None)
+        return nx.sum(G_norm, axis=2).T
+
+    assert np.allclose(f(G), f2(G))
+    assert np.allclose(df(G), df2(G))
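The clip-based guard in df2 replaces the old `if n:` branch. A small sketch of the gradient identity both implement, d||x||/dx = x / ||x|| with a guarded zero case (standalone NumPy, illustrative values):

import numpy as np

eps = 1e-12
x = np.array([3.0, 4.0])
grad = x / np.clip(np.linalg.norm(x), eps, None)
assert np.allclose(grad, [0.6, 0.8])

# At x = 0, clipping the norm avoids a 0/0 division and yields a
# zero (sub)gradient, matching the old `if n:` short-circuit.
z = np.zeros(2)
assert np.all(z / np.clip(np.linalg.norm(z), eps, None) == 0)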
