Skip to content

Commit 1bccd47

Browse files
committed
fix bug in doctest
1 parent e8298d0 commit 1bccd47

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

ot/unbalanced.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
244244
--------
245245
246246
>>> import ot
247+
>>> import numpy as np
247248
>>> a=[.5, .10]
248249
>>> b=[.5, .5]
249250
>>> M=[[0., 1.],[1., 0.]]
250-
>>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.)
251-
0.3191285827553562
251+
>>> np.round(ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.), 8)
252+
0.31912858
252253
253254
.. _references-sinkhorn-unbalanced2:
254255
References

test/test_unbalanced.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ot.unbalanced import barycenter_unbalanced
1515

1616

17-
@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"]))
17+
@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"]))
1818
def test_unbalanced_convergence(nx, method, reg_type):
1919
# test generalized sinkhorn for unbalanced OT
2020
n = 100
@@ -78,7 +78,7 @@ def test_unbalanced_convergence(nx, method, reg_type):
7878
np.testing.assert_allclose(G_np, nx.to_numpy(G))
7979

8080

81-
@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized"], ["kl", "entropy"]))
81+
@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"]))
8282
def test_unbalanced_warmstart(nx, method, reg_type):
8383
# test generalized sinkhorn for unbalanced OT
8484
n = 100
@@ -140,7 +140,7 @@ def test_unbalanced_warmstart(nx, method, reg_type):
140140
np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5)
141141

142142

143-
@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")]))
143+
@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], [1, float("inf")]))
144144
def test_unbalanced_relaxation_parameters(nx, method, reg_m):
145145
# test generalized sinkhorn for unbalanced OT
146146
n = 100
@@ -184,7 +184,7 @@ def test_unbalanced_relaxation_parameters(nx, method, reg_m):
184184
nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05)
185185

186186

187-
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
187+
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"])
188188
def test_unbalanced_multiple_inputs(nx, method):
189189
# test generalized sinkhorn for unbalanced OT
190190
n = 100

0 commit comments

Comments
 (0)