@@ -104,16 +104,16 @@ def test_unbalanced_warmstart(nx, method, reg_type):
104
104
reg_type = reg_type , warmstart = None , verbose = True
105
105
)
106
106
107
- # dim_a, dim_b = M.shape
108
- # warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M))
109
- # G, log = ot.unbalanced.sinkhorn_unbalanced(
110
- # a, b, M, reg=epsilon, reg_m=reg_m, method=method,
111
- # reg_type=reg_type, warmstart=warmstart, log=True, verbose=True
112
- # )
113
- # loss = ot.unbalanced.sinkhorn_unbalanced2(
114
- # a, b, M, reg=epsilon, reg_m=reg_m, method=method,
115
- # reg_type=reg_type, warmstart=warmstart, verbose=True
116
- # )
107
+ dim_a , dim_b = M .shape
108
+ warmstart = (nx .zeros (dim_a , type_as = M ), nx .zeros (dim_b , type_as = M ))
109
+ G , log = ot .unbalanced .sinkhorn_unbalanced (
110
+ a , b , M , reg = epsilon , reg_m = reg_m , method = method ,
111
+ reg_type = reg_type , warmstart = warmstart , log = True , verbose = True
112
+ )
113
+ loss = ot .unbalanced .sinkhorn_unbalanced2 (
114
+ a , b , M , reg = epsilon , reg_m = reg_m , method = method ,
115
+ reg_type = reg_type , warmstart = warmstart , verbose = True
116
+ )
117
117
118
118
_ , log = ot .lp .emd (a , b , M , log = True )
119
119
warmstart1 = (log ["u" ], log ["v" ])
@@ -126,19 +126,19 @@ def test_unbalanced_warmstart(nx, method, reg_type):
126
126
reg_type = reg_type , warmstart = warmstart1 , verbose = True
127
127
)
128
128
129
- # np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5)
129
+ np .testing .assert_allclose (nx .to_numpy (loss ), nx .to_numpy (loss0 ), atol = 1e-5 )
130
130
np .testing .assert_allclose (nx .to_numpy (loss0 ), nx .to_numpy (loss1 ), atol = 1e-5 )
131
131
132
- # np.testing.assert_allclose(
133
- # nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05)
134
- # np.testing.assert_allclose(
135
- # nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05)
132
+ np .testing .assert_allclose (
133
+ nx .to_numpy (log ["logu" ]), nx .to_numpy (log0 ["logu" ]), atol = 1e-05 )
134
+ np .testing .assert_allclose (
135
+ nx .to_numpy (log ["logv" ]), nx .to_numpy (log0 ["logv" ]), atol = 1e-05 )
136
136
np .testing .assert_allclose (
137
137
nx .to_numpy (log0 ["logu" ]), nx .to_numpy (log1 ["logu" ]), atol = 1e-05 )
138
138
np .testing .assert_allclose (
139
139
nx .to_numpy (log0 ["logv" ]), nx .to_numpy (log1 ["logv" ]), atol = 1e-05 )
140
140
141
- # np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05)
141
+ np .testing .assert_allclose (nx .to_numpy (G ), nx .to_numpy (G0 ), atol = 1e-05 )
142
142
np .testing .assert_allclose (nx .to_numpy (G0 ), nx .to_numpy (G1 ), atol = 1e-05 )
143
143
144
144
0 commit comments