@@ -140,6 +140,36 @@ def test_unbalanced_warmstart(nx, method, reg_type):
140
140
np .testing .assert_allclose (nx .to_numpy (loss0 ), nx .to_numpy (loss1 ), atol = 1e-5 )
141
141
142
142
143
+ @pytest .mark .parametrize ("method,reg_type, log" , itertools .product (["sinkhorn" , "sinkhorn_stabilized" , "sinkhorn_reg_scaling" ], ["kl" , "entropy" ], [True , False ]))
144
+ def test_sinkhorn_unbalanced2 (nx , method , reg_type , log ):
145
+ n = 100
146
+ rng = np .random .RandomState (42 )
147
+
148
+ x = rng .randn (n , 2 )
149
+ a = ot .utils .unif (n )
150
+
151
+ # make dists unbalanced
152
+ b = ot .utils .unif (n ) * 1.5
153
+ M = ot .dist (x , x )
154
+ a , b , M = nx .from_numpy (a , b , M )
155
+
156
+ epsilon = 1.
157
+ reg_m = 1.
158
+
159
+ loss = nx .to_numpy (ot .unbalanced .sinkhorn_unbalanced2 (
160
+ a , b , M , reg = epsilon , reg_m = reg_m , method = method ,
161
+ reg_type = reg_type , log = False , verbose = True
162
+ ))
163
+
164
+ res = ot .unbalanced .sinkhorn_unbalanced2 (
165
+ a , b , M , reg = epsilon , reg_m = reg_m , method = method ,
166
+ reg_type = reg_type , log = log , verbose = True
167
+ )
168
+ loss0 = res [0 ] if log else res
169
+
170
+ np .testing .assert_allclose (nx .to_numpy (loss ), nx .to_numpy (loss0 ), atol = 1e-5 )
171
+
172
+
143
173
@pytest .mark .parametrize ("method,reg_m" , itertools .product (["sinkhorn" , "sinkhorn_stabilized" , "sinkhorn_reg_scaling" ], [1 , float ("inf" )]))
144
174
def test_unbalanced_relaxation_parameters (nx , method , reg_m ):
145
175
# test generalized sinkhorn for unbalanced OT
@@ -202,11 +232,10 @@ def test_unbalanced_multiple_inputs(nx, method):
202
232
203
233
a , b , M = nx .from_numpy (a , b , M )
204
234
205
- loss , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon ,
206
- reg_m = reg_m ,
207
- method = method ,
208
- log = True ,
209
- verbose = True )
235
+ G , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon ,
236
+ reg_m = reg_m , method = method ,
237
+ log = True , verbose = True )
238
+
210
239
# check fixed point equations
211
240
# in log-domain
212
241
fi = reg_m / (reg_m + epsilon )
0 commit comments