Skip to content

Commit eb0ca02

Browse files
committed
Vectorize gradient for sinkhorn l1l2
1 parent 6dd9877 commit eb0ca02

File tree

3 files changed

+46
-32
lines changed

3 files changed

+46
-32
lines changed

ot/backend.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def power(self, a, exponents):
407407
"""
408408
raise NotImplementedError()
409409

410-
def norm(self, a, axis=None):
410+
def norm(self, a, axis=None, keepdims=False):
411411
r"""
412412
Computes the matrix frobenius norm.
413413
@@ -1087,8 +1087,8 @@ def sqrt(self, a):
10871087
def power(self, a, exponents):
10881088
return np.power(a, exponents)
10891089

1090-
def norm(self, a, axis=None):
1091-
return np.linalg.norm(a, axis=axis)
1090+
def norm(self, a, axis=None, keepdims=False):
1091+
return np.linalg.norm(a, axis=axis, keepdims=keepdims)
10921092

10931093
def any(self, a):
10941094
return np.any(a)
@@ -1461,8 +1461,8 @@ def sqrt(self, a):
14611461
def power(self, a, exponents):
14621462
return jnp.power(a, exponents)
14631463

1464-
def norm(self, a, axis=None):
1465-
return jnp.linalg.norm(a, axis=axis)
1464+
def norm(self, a, axis=None, keepdims=False):
1465+
return jnp.linalg.norm(a, axis=axis, keepdims=keepdims)
14661466

14671467
def any(self, a):
14681468
return jnp.any(a)
@@ -1881,8 +1881,8 @@ def sqrt(self, a):
18811881
def power(self, a, exponents):
18821882
return torch.pow(a, exponents)
18831883

1884-
def norm(self, a, axis=None):
1885-
return torch.linalg.norm(a.double(), dim=axis)
1884+
def norm(self, a, axis=None, keepdims=False):
    """Return the norm of ``a`` computed by ``torch.linalg.norm``.

    Parameters
    ----------
    a : torch.Tensor
        Input tensor; it is cast to double precision before the reduction.
    axis : int or tuple of int, optional
        Dimension(s) to reduce over, forwarded as ``dim``. ``None`` reduces
        over the whole tensor.
    keepdims : bool, optional
        If True, the reduced dimensions are kept with size one.

    Returns
    -------
    torch.Tensor
        Double-precision tensor holding the norm.
    """
    # torch documents this keyword as ``keepdim``; use the documented name
    # rather than relying on the NumPy-style ``keepdims`` alias.
    return torch.linalg.norm(a.double(), dim=axis, keepdim=keepdims)
18861886

18871887
def any(self, a):
18881888
return torch.any(a)
@@ -2306,8 +2306,8 @@ def power(self, a, exponents):
23062306
def dot(self, a, b):
23072307
return cp.dot(a, b)
23082308

2309-
def norm(self, a, axis=None):
2310-
return cp.linalg.norm(a, axis=axis)
2309+
def norm(self, a, axis=None, keepdims=False):
2310+
return cp.linalg.norm(a, axis=axis, keepdims=keepdims)
23112311

23122312
def any(self, a):
23132313
return cp.any(a)
@@ -2717,8 +2717,8 @@ def sqrt(self, a):
27172717
def power(self, a, exponents):
27182718
return tnp.power(a, exponents)
27192719

2720-
def norm(self, a, axis=None):
2721-
return tf.math.reduce_euclidean_norm(a, axis=axis)
2720+
def norm(self, a, axis=None, keepdims=False):
2721+
return tf.math.reduce_euclidean_norm(a, axis=axis, keepdims=keepdims)
27222722

27232723
def any(self, a):
27242724
return tnp.any(a)

ot/da.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
148148

149149

150150
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
151-
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
152-
log=False):
151+
numInnerItermax=200, stopInnerThr=1e-9, eps=1e-12,
152+
verbose=False, log=False):
153153
r"""
154154
Solve the entropic regularization optimal transport problem with group
155155
lasso regularization
@@ -202,6 +202,8 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
202202
Max number of iterations (inner sinkhorn solver)
203203
stopInnerThr : float, optional
204204
Stop threshold on error (inner sinkhorn solver) (>0)
205+
eps : float, optional (default=1e-12)
206+
Lower bound applied to the group norms to avoid division by zero
205207
verbose : bool, optional
206208
Print information along iterations
207209
log : bool, optional
@@ -241,19 +243,13 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
241243

242244
def f(G):
243245
G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
244-
return nx.norm(G_split * unroll_labels_idx, axis=1).sum()
245-
246-
lstlab = nx.unique(labels_a)
246+
return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))
247247

248248
def df(G):
    """Vectorized gradient of the group-lasso term evaluated by ``f``.

    ``G.T`` is replicated ``n_labels`` times along a new last axis and
    masked with ``unroll_labels_idx``, so each slice keeps only the entries
    of one label group. Each group is divided by its norm (taken along
    axis 1), the slices are summed back together and the result is
    transposed to the layout of ``G``.
    """
    # NOTE(review): assumes unroll_labels_idx is a 0/1 one-hot mask (as built
    # in the accompanying test), so masking once is enough -- confirm.
    G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
    W = nx.norm(G_split, axis=1, keepdims=True)
    # Clip the norms from below so all-zero groups yield 0 instead of NaN.
    G_norm = G_split / nx.clip(W, eps, None)
    return nx.sum(G_norm, axis=2).T
257253

258254
return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
259255
numInnerItermax=numInnerItermax, stopThr=stopInnerThr,

test/test_da.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -802,30 +802,48 @@ def test_emd_laplace_class(nx):
802802
assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))
803803

804804

805-
def test_sinkhorn_l1l2_gl_cost_vectorized():
805+
def test_sinkhorn_l1l2_gl_cost_vectorized(nx):
806806
n_samples, n_labels = 150, 3
807807
rng = np.random.RandomState(42)
808808
G = rng.rand(n_samples, n_samples)
809809
labels_a = rng.randint(n_labels, size=(n_samples,))
810+
G, labels_a = nx.from_numpy(G), nx.from_numpy(labels_a)
810811

811812
# previously used implementation for the cost estimator
812-
lstlab = np.unique(labels_a)
813+
lstlab = nx.unique(labels_a)
813814

814815
def f(G):
815816
res = 0
816817
for i in range(G.shape[1]):
817818
for lab in lstlab:
818819
temp = G[labels_a == lab, i]
819-
res += np.linalg.norm(temp)
820+
res += nx.norm(temp)
820821
return res
821822

823+
def df(G):
824+
W = nx.zeros(G.shape, type_as=G)
825+
for i in range(G.shape[1]):
826+
for lab in lstlab:
827+
temp = G[labels_a == lab, i]
828+
n = nx.norm(temp)
829+
if n:
830+
W[labels_a == lab, i] = temp / n
831+
return W
832+
822833
# new vectorized implementation for the cost estimator
823-
lstlab, lstlab_idx = np.unique(labels_a, return_inverse=True)
824-
n_samples = lstlab.shape[0]
825-
midx = np.eye(n_samples, dtype='int32')[None, lstlab_idx]
834+
labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
835+
n_labels = labels_u.shape[0]
836+
unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
826837

827838
def f2(G):
828-
G_split = np.repeat(G.T[:, :, None], n_samples, axis=2)
829-
return np.linalg.norm(G_split * midx, axis=1).sum()
839+
G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
840+
return nx.sum(nx.norm(G_split * unroll_labels_idx, axis=1))
841+
842+
def df2(G):
    """Vectorized gradient, checked against the loop-based ``df``.

    ``G.T`` is replicated per label, masked with ``unroll_labels_idx``, each
    masked group is divided by its norm along axis 1, and the label slices
    are summed and transposed back to the layout of ``G``.
    """
    # unroll_labels_idx is a one-hot mask built from nx.eye, so one
    # multiplication suffices; masking again before the norm is a no-op.
    G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2) * unroll_labels_idx
    W = nx.norm(G_split, axis=1, keepdims=True)
    # Lower-bound the norms so empty/all-zero groups do not divide by zero.
    G_norm = G_split / nx.clip(W, 1e-12, None)
    return nx.sum(G_norm, axis=2).T
830847

831848
assert np.allclose(f(G), f2(G))
849+
assert np.allclose(df(G), df2(G))

0 commit comments

Comments
 (0)