Skip to content

Commit 60cd0d7

Browse files
committed
fix test
1 parent 51d6e24 commit 60cd0d7

File tree

1 file changed

+36
-19
lines changed

1 file changed

+36
-19
lines changed

test/test_unbalanced.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,41 +88,58 @@ def test_unbalanced_warmstart(nx, method, reg_type):
8888

8989
x = rng.randn(n, 2)
9090
a = ot.utils.unif(n)
91-
92-
# make dists unbalanced
93-
b = ot.utils.unif(n) * 1.5
91+
b = ot.utils.unif(n)
9492
M = ot.dist(x, x)
9593
a, b, M = nx.from_numpy(a, b, M)
9694

9795
epsilon = 1.
9896
reg_m = 1.
9997

100-
dim_a, dim_b = M.shape
101-
warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M))
102-
G, log = ot.unbalanced.sinkhorn_unbalanced(
98+
G0, log0 = ot.unbalanced.sinkhorn_unbalanced(
10399
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
104-
reg_type=reg_type, warmstart=warmstart, log=True, verbose=True
100+
reg_type=reg_type, warmstart=None, log=True, verbose=True
105101
)
106-
loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
102+
loss0 = ot.unbalanced.sinkhorn_unbalanced2(
107103
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
108-
reg_type=reg_type, warmstart=warmstart, verbose=True
109-
))
104+
reg_type=reg_type, warmstart=None, verbose=True
105+
)
110106

111-
G0, log0 = ot.unbalanced.sinkhorn_unbalanced(
107+
# dim_a, dim_b = M.shape
108+
# warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M))
109+
# G, log = ot.unbalanced.sinkhorn_unbalanced(
110+
# a, b, M, reg=epsilon, reg_m=reg_m, method=method,
111+
# reg_type=reg_type, warmstart=warmstart, log=True, verbose=True
112+
# )
113+
# loss = ot.unbalanced.sinkhorn_unbalanced2(
114+
# a, b, M, reg=epsilon, reg_m=reg_m, method=method,
115+
# reg_type=reg_type, warmstart=warmstart, verbose=True
116+
# )
117+
118+
_, log = ot.lp.emd(a, b, M, log=True)
119+
warmstart1 = (log["u"], log["v"])
120+
G1, log1 = ot.unbalanced.sinkhorn_unbalanced(
112121
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
113-
reg_type=reg_type, warmstart=None, log=True, verbose=True
122+
reg_type=reg_type, warmstart=warmstart1, log=True, verbose=True
114123
)
115-
loss0 = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
124+
loss1 = ot.unbalanced.sinkhorn_unbalanced2(
116125
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
117-
reg_type=reg_type, warmstart=None, verbose=True
118-
))
126+
reg_type=reg_type, warmstart=warmstart1, verbose=True
127+
)
128+
129+
# np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5)
130+
np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5)
119131

120-
np.testing.assert_allclose(loss, loss0, atol=1e-5)
132+
# np.testing.assert_allclose(
133+
# nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05)
134+
# np.testing.assert_allclose(
135+
# nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05)
121136
np.testing.assert_allclose(
122-
nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05)
137+
nx.to_numpy(log0["logu"]), nx.to_numpy(log1["logu"]), atol=1e-05)
123138
np.testing.assert_allclose(
124-
nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05)
125-
np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05)
139+
nx.to_numpy(log0["logv"]), nx.to_numpy(log1["logv"]), atol=1e-05)
140+
141+
# np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05)
142+
np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05)
126143

127144

128145
@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized"], [1, float("inf")]))

0 commit comments

Comments
 (0)