Skip to content

Commit dcb19c6

Browse files
committed
Fix flake
1 parent f4ca25a commit dcb19c6

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

ot/da.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,13 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
238238
labels_u, labels_idx = nx.unique(labels_a, return_inverse=True)
239239
n_labels = labels_u.shape[0]
240240
unroll_labels_idx = nx.eye(n_labels, type_as=labels_u)[None, labels_idx]
241+
241242
def f(G):
242243
G_split = nx.repeat(G.T[:, :, None], n_labels, axis=2)
243244
return nx.norm(G_split * unroll_labels_idx, axis=1).sum()
244245

245246
lstlab = nx.unique(labels_a)
247+
246248
def df(G):
247249
W = nx.zeros(G.shape, type_as=G)
248250
for i in range(G.shape[1]):

test/test_da.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,7 @@ def test_sinkhorn_l1l2_gl_cost_vectorized():
810810

811811
# previously used implementation for the cost estimator
812812
lstlab = np.unique(labels_a)
813+
813814
def f(G):
814815
res = 0
815816
for i in range(G.shape[1]):
@@ -822,8 +823,9 @@ def f(G):
822823
lstlab, lstlab_idx = np.unique(labels_a, return_inverse=True)
823824
n_samples = lstlab.shape[0]
824825
midx = np.eye(n_samples, dtype='int32')[None, lstlab_idx]
826+
825827
def f2(G):
826828
G_split = np.repeat(G.T[:, :, None], n_samples, axis=2)
827829
return np.linalg.norm(G_split * midx, axis=1).sum()
828830

829-
assert np.allclose(f(G), f2(G))
831+
assert np.allclose(f(G), f2(G))

0 commit comments

Comments
 (0)