Skip to content

Commit c2c0e96

Browse files
committed
add test for more coverage
1 parent 1bccd47 commit c2c0e96

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

test/test_unbalanced.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,36 @@ def test_unbalanced_warmstart(nx, method, reg_type):
140140
np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5)
141141

142142

143+
@pytest.mark.parametrize("method,reg_type, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"], [True, False]))
144+
def test_sinkhorn_unbalanced2(nx, method, reg_type, log):
145+
n = 100
146+
rng = np.random.RandomState(42)
147+
148+
x = rng.randn(n, 2)
149+
a = ot.utils.unif(n)
150+
151+
# make dists unbalanced
152+
b = ot.utils.unif(n) * 1.5
153+
M = ot.dist(x, x)
154+
a, b, M = nx.from_numpy(a, b, M)
155+
156+
epsilon = 1.
157+
reg_m = 1.
158+
159+
loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
160+
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
161+
reg_type=reg_type, log=False, verbose=True
162+
))
163+
164+
res = ot.unbalanced.sinkhorn_unbalanced2(
165+
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
166+
reg_type=reg_type, log=log, verbose=True
167+
)
168+
loss0 = res[0] if log else res
169+
170+
np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5)
171+
172+
143173
@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], [1, float("inf")]))
144174
def test_unbalanced_relaxation_parameters(nx, method, reg_m):
145175
# test generalized sinkhorn for unbalanced OT
@@ -202,11 +232,10 @@ def test_unbalanced_multiple_inputs(nx, method):
202232

203233
a, b, M = nx.from_numpy(a, b, M)
204234

205-
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
206-
reg_m=reg_m,
207-
method=method,
208-
log=True,
209-
verbose=True)
235+
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
236+
reg_m=reg_m, method=method,
237+
log=True, verbose=True)
238+
210239
# check fixed point equations
211240
# in log-domain
212241
fi = reg_m / (reg_m + epsilon)

0 commit comments

Comments
 (0)