Skip to content

Commit c0d8391

Browse files
committed
fix test
1 parent 60cd0d7 commit c0d8391

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

ot/unbalanced.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
443443
v = nx.ones(dim_b, type_as=M)
444444
else:
445445
u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
446+
if not n_hists:
447+
u, v = u.reshape(-1), v.reshape(-1)
446448

447449
if reg_type == "kl":
448450
K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :]
@@ -652,6 +654,8 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
652654
v = nx.ones(dim_b, type_as=M)
653655
else:
654656
u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
657+
if not n_hists:
658+
u, v = u.reshape(-1), v.reshape(-1)
655659

656660
if reg_type == "kl":
657661
log_ab = nx.log(a + 1e-16).reshape(-1)[:, None] + nx.log(b + 1e-16).reshape(-1)[None, :]

test/test_unbalanced.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,16 @@ def test_unbalanced_warmstart(nx, method, reg_type):
104104
reg_type=reg_type, warmstart=None, verbose=True
105105
)
106106

107-
# dim_a, dim_b = M.shape
108-
# warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M))
109-
# G, log = ot.unbalanced.sinkhorn_unbalanced(
110-
# a, b, M, reg=epsilon, reg_m=reg_m, method=method,
111-
# reg_type=reg_type, warmstart=warmstart, log=True, verbose=True
112-
# )
113-
# loss = ot.unbalanced.sinkhorn_unbalanced2(
114-
# a, b, M, reg=epsilon, reg_m=reg_m, method=method,
115-
# reg_type=reg_type, warmstart=warmstart, verbose=True
116-
# )
107+
dim_a, dim_b = M.shape
108+
warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M))
109+
G, log = ot.unbalanced.sinkhorn_unbalanced(
110+
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
111+
reg_type=reg_type, warmstart=warmstart, log=True, verbose=True
112+
)
113+
loss = ot.unbalanced.sinkhorn_unbalanced2(
114+
a, b, M, reg=epsilon, reg_m=reg_m, method=method,
115+
reg_type=reg_type, warmstart=warmstart, verbose=True
116+
)
117117

118118
_, log = ot.lp.emd(a, b, M, log=True)
119119
warmstart1 = (log["u"], log["v"])
@@ -126,19 +126,19 @@ def test_unbalanced_warmstart(nx, method, reg_type):
126126
reg_type=reg_type, warmstart=warmstart1, verbose=True
127127
)
128128

129-
# np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5)
129+
np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5)
130130
np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5)
131131

132-
# np.testing.assert_allclose(
133-
# nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05)
134-
# np.testing.assert_allclose(
135-
# nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05)
132+
np.testing.assert_allclose(
133+
nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05)
134+
np.testing.assert_allclose(
135+
nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05)
136136
np.testing.assert_allclose(
137137
nx.to_numpy(log0["logu"]), nx.to_numpy(log1["logu"]), atol=1e-05)
138138
np.testing.assert_allclose(
139139
nx.to_numpy(log0["logv"]), nx.to_numpy(log1["logv"]), atol=1e-05)
140140

141-
# np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05)
141+
np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05)
142142
np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05)
143143

144144

0 commit comments

Comments
 (0)