Skip to content

Commit 4db194a

Browse files
committed
fix bug in test
1 parent c0d8391 commit 4db194a

File tree

2 files changed

+50
-31
lines changed

2 files changed

+50
-31
lines changed

ot/unbalanced.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -277,30 +277,55 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
277277
ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
278278
279279
"""
280-
b = list_to_array(b)
280+
M, a, b = list_to_array(M, a, b)
281+
nx = get_backend(M, a, b)
282+
281283
if len(b.shape) < 2:
282-
b = b[:, None]
284+
if method.lower() == 'sinkhorn':
285+
res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type,
286+
warmstart, numItermax=numItermax,
287+
stopThr=stopThr, verbose=verbose,
288+
log=log, **kwargs)
289+
290+
elif method.lower() == 'sinkhorn_stabilized':
291+
res = sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type,
292+
warmstart, numItermax=numItermax,
293+
stopThr=stopThr, verbose=verbose,
294+
log=log, **kwargs)
295+
elif method.lower() in ['sinkhorn_reg_scaling']:
296+
warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp')
297+
res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type,
298+
warmstart, numItermax=numItermax,
299+
stopThr=stopThr, verbose=verbose,
300+
log=log, **kwargs)
301+
else:
302+
raise ValueError('Unknown method %s.' % method)
283303

284-
if method.lower() == 'sinkhorn':
285-
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type,
286-
warmstart, numItermax=numItermax,
287-
stopThr=stopThr, verbose=verbose,
288-
log=log, **kwargs)
304+
if log:
305+
return nx.sum(M * res[0]), res[1]
306+
else:
307+
return nx.sum(M * res)
289308

290-
elif method.lower() == 'sinkhorn_stabilized':
291-
return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type,
292-
warmstart, numItermax=numItermax,
293-
stopThr=stopThr,
294-
verbose=verbose,
295-
log=log, **kwargs)
296-
elif method.lower() in ['sinkhorn_reg_scaling']:
297-
warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp')
298-
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type,
299-
warmstart, numItermax=numItermax,
300-
stopThr=stopThr, verbose=verbose,
301-
log=log, **kwargs)
302309
else:
303-
raise ValueError('Unknown method %s.' % method)
310+
if method.lower() == 'sinkhorn':
311+
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type,
312+
warmstart, numItermax=numItermax,
313+
stopThr=stopThr, verbose=verbose,
314+
log=log, **kwargs)
315+
316+
elif method.lower() == 'sinkhorn_stabilized':
317+
return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type,
318+
warmstart, numItermax=numItermax,
319+
stopThr=stopThr, verbose=verbose,
320+
log=log, **kwargs)
321+
elif method.lower() in ['sinkhorn_reg_scaling']:
322+
warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp')
323+
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type,
324+
warmstart, numItermax=numItermax,
325+
stopThr=stopThr, verbose=verbose,
326+
log=log, **kwargs)
327+
else:
328+
raise ValueError('Unknown method %s.' % method)
304329

305330

306331
def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
@@ -443,8 +468,6 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
443468
v = nx.ones(dim_b, type_as=M)
444469
else:
445470
u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
446-
if not n_hists:
447-
u, v = u.reshape(-1), v.reshape(-1)
448471

449472
if reg_type == "kl":
450473
K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :]
@@ -654,8 +677,6 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
654677
v = nx.ones(dim_b, type_as=M)
655678
else:
656679
u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
657-
if not n_hists:
658-
u, v = u.reshape(-1), v.reshape(-1)
659680

660681
if reg_type == "kl":
661682
log_ab = nx.log(a + 1e-16).reshape(-1)[:, None] + nx.log(b + 1e-16).reshape(-1)[None, :]

test/test_unbalanced.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616

1717
@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"]))
18-
# @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
1918
def test_unbalanced_convergence(nx, method, reg_type):
2019
# test generalized sinkhorn for unbalanced OT
2120
n = 100
@@ -80,7 +79,6 @@ def test_unbalanced_convergence(nx, method, reg_type):
8079

8180

8281
@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"]))
83-
# @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
8482
def test_unbalanced_warmstart(nx, method, reg_type):
8583
# test generalized sinkhorn for unbalanced OT
8684
n = 100
@@ -115,8 +113,8 @@ def test_unbalanced_warmstart(nx, method, reg_type):
115113
reg_type=reg_type, warmstart=warmstart, verbose=True
116114
)
117115

118-
_, log = ot.lp.emd(a, b, M, log=True)
119-
warmstart1 = (log["u"], log["v"])
116+
_, log_emd = ot.lp.emd(a, b, M, log=True)
117+
warmstart1 = (log_emd["u"], log_emd["v"])
120118
G1, log1 = ot.unbalanced.sinkhorn_unbalanced(
121119
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
122120
reg_type=reg_type, warmstart=warmstart1, log=True, verbose=True
@@ -126,9 +124,6 @@ def test_unbalanced_warmstart(nx, method, reg_type):
126124
reg_type=reg_type, warmstart=warmstart1, verbose=True
127125
)
128126

129-
np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5)
130-
np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5)
131-
132127
np.testing.assert_allclose(
133128
nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05)
134129
np.testing.assert_allclose(
@@ -141,6 +136,9 @@ def test_unbalanced_warmstart(nx, method, reg_type):
141136
np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05)
142137
np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05)
143138

139+
np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5)
140+
np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5)
141+
144142

145143
@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")]))
146144
def test_unbalanced_relaxation_parameters(nx, method, reg_m):

0 commit comments

Comments
 (0)