@@ -88,41 +88,58 @@ def test_unbalanced_warmstart(nx, method, reg_type):
88
88
89
89
x = rng .randn (n , 2 )
90
90
a = ot .utils .unif (n )
91
-
92
- # make dists unbalanced
93
- b = ot .utils .unif (n ) * 1.5
91
+ b = ot .utils .unif (n )
94
92
M = ot .dist (x , x )
95
93
a , b , M = nx .from_numpy (a , b , M )
96
94
97
95
epsilon = 1.
98
96
reg_m = 1.
99
97
100
- dim_a , dim_b = M .shape
101
- warmstart = (nx .zeros (dim_a , type_as = M ), nx .zeros (dim_b , type_as = M ))
102
- G , log = ot .unbalanced .sinkhorn_unbalanced (
98
+ G0 , log0 = ot .unbalanced .sinkhorn_unbalanced (
103
99
a , b , M , reg = epsilon , reg_m = reg_m , method = method ,
104
- reg_type = reg_type , warmstart = warmstart , log = True , verbose = True
100
+ reg_type = reg_type , warmstart = None , log = True , verbose = True
105
101
)
106
- loss = nx . to_numpy ( ot .unbalanced .sinkhorn_unbalanced2 (
102
+ loss0 = ot .unbalanced .sinkhorn_unbalanced2 (
107
103
a , b , M , reg = epsilon , reg_m = reg_m , method = method ,
108
- reg_type = reg_type , warmstart = warmstart , verbose = True
109
- ))
104
+ reg_type = reg_type , warmstart = None , verbose = True
105
+ )
110
106
111
- G0 , log0 = ot .unbalanced .sinkhorn_unbalanced (
107
+ # dim_a, dim_b = M.shape
108
+ # warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M))
109
+ # G, log = ot.unbalanced.sinkhorn_unbalanced(
110
+ # a, b, M, reg=epsilon, reg_m=reg_m, method=method,
111
+ # reg_type=reg_type, warmstart=warmstart, log=True, verbose=True
112
+ # )
113
+ # loss = ot.unbalanced.sinkhorn_unbalanced2(
114
+ # a, b, M, reg=epsilon, reg_m=reg_m, method=method,
115
+ # reg_type=reg_type, warmstart=warmstart, verbose=True
116
+ # )
117
+
118
+ _ , log = ot .lp .emd (a , b , M , log = True )
119
+ warmstart1 = (log ["u" ], log ["v" ])
120
+ G1 , log1 = ot .unbalanced .sinkhorn_unbalanced (
112
121
a , b , M , reg = epsilon , reg_m = reg_m , method = method ,
113
- reg_type = reg_type , warmstart = None , log = True , verbose = True
122
+ reg_type = reg_type , warmstart = warmstart1 , log = True , verbose = True
114
123
)
115
- loss0 = nx . to_numpy ( ot .unbalanced .sinkhorn_unbalanced2 (
124
+ loss1 = ot .unbalanced .sinkhorn_unbalanced2 (
116
125
a , b , M , reg = epsilon , reg_m = reg_m , method = method ,
117
- reg_type = reg_type , warmstart = None , verbose = True
118
- ))
126
+ reg_type = reg_type , warmstart = warmstart1 , verbose = True
127
+ )
128
+
129
+ # np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5)
130
+ np .testing .assert_allclose (nx .to_numpy (loss0 ), nx .to_numpy (loss1 ), atol = 1e-5 )
119
131
120
- np .testing .assert_allclose (loss , loss0 , atol = 1e-5 )
132
+ # np.testing.assert_allclose(
133
+ # nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05)
134
+ # np.testing.assert_allclose(
135
+ # nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05)
121
136
np .testing .assert_allclose (
122
- nx .to_numpy (log ["logu" ]), nx .to_numpy (log0 ["logu" ]), atol = 1e-05 )
137
+ nx .to_numpy (log0 ["logu" ]), nx .to_numpy (log1 ["logu" ]), atol = 1e-05 )
123
138
np .testing .assert_allclose (
124
- nx .to_numpy (log ["logv" ]), nx .to_numpy (log0 ["logv" ]), atol = 1e-05 )
125
- np .testing .assert_allclose (nx .to_numpy (G ), nx .to_numpy (G0 ), atol = 1e-05 )
139
+ nx .to_numpy (log0 ["logv" ]), nx .to_numpy (log1 ["logv" ]), atol = 1e-05 )
140
+
141
+ # np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05)
142
+ np .testing .assert_allclose (nx .to_numpy (G0 ), nx .to_numpy (G1 ), atol = 1e-05 )
126
143
127
144
128
145
@pytest .mark .parametrize ("method,reg_m" , itertools .product (["sinkhorn" , "sinkhorn_stabilized" ], [1 , float ("inf" )]))
0 commit comments