@@ -277,30 +277,55 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
277
277
ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
278
278
279
279
"""
280
- b = list_to_array (b )
280
+ M , a , b = list_to_array (M , a , b )
281
+ nx = get_backend (M , a , b )
282
+
281
283
if len (b .shape ) < 2 :
282
- b = b [:, None ]
284
+ if method .lower () == 'sinkhorn' :
285
+ res = sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m , reg_type ,
286
+ warmstart , numItermax = numItermax ,
287
+ stopThr = stopThr , verbose = verbose ,
288
+ log = log , ** kwargs )
289
+
290
+ elif method .lower () == 'sinkhorn_stabilized' :
291
+ res = sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m , reg_type ,
292
+ warmstart , numItermax = numItermax ,
293
+ stopThr = stopThr , verbose = verbose ,
294
+ log = log , ** kwargs )
295
+ elif method .lower () in ['sinkhorn_reg_scaling' ]:
296
+ warnings .warn ('Method not implemented yet. Using classic Sinkhorn-Knopp' )
297
+ res = sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m , reg_type ,
298
+ warmstart , numItermax = numItermax ,
299
+ stopThr = stopThr , verbose = verbose ,
300
+ log = log , ** kwargs )
301
+ else :
302
+ raise ValueError ('Unknown method %s.' % method )
283
303
284
- if method .lower () == 'sinkhorn' :
285
- return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m , reg_type ,
286
- warmstart , numItermax = numItermax ,
287
- stopThr = stopThr , verbose = verbose ,
288
- log = log , ** kwargs )
304
+ if log :
305
+ return nx .sum (M * res [0 ]), res [1 ]
306
+ else :
307
+ return nx .sum (M * res )
289
308
290
- elif method .lower () == 'sinkhorn_stabilized' :
291
- return sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m , reg_type ,
292
- warmstart , numItermax = numItermax ,
293
- stopThr = stopThr ,
294
- verbose = verbose ,
295
- log = log , ** kwargs )
296
- elif method .lower () in ['sinkhorn_reg_scaling' ]:
297
- warnings .warn ('Method not implemented yet. Using classic Sinkhorn-Knopp' )
298
- return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m , reg_type ,
299
- warmstart , numItermax = numItermax ,
300
- stopThr = stopThr , verbose = verbose ,
301
- log = log , ** kwargs )
302
309
else :
303
- raise ValueError ('Unknown method %s.' % method )
310
+ if method .lower () == 'sinkhorn' :
311
+ return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m , reg_type ,
312
+ warmstart , numItermax = numItermax ,
313
+ stopThr = stopThr , verbose = verbose ,
314
+ log = log , ** kwargs )
315
+
316
+ elif method .lower () == 'sinkhorn_stabilized' :
317
+ return sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m , reg_type ,
318
+ warmstart , numItermax = numItermax ,
319
+ stopThr = stopThr , verbose = verbose ,
320
+ log = log , ** kwargs )
321
+ elif method .lower () in ['sinkhorn_reg_scaling' ]:
322
+ warnings .warn ('Method not implemented yet. Using classic Sinkhorn-Knopp' )
323
+ return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m , reg_type ,
324
+ warmstart , numItermax = numItermax ,
325
+ stopThr = stopThr , verbose = verbose ,
326
+ log = log , ** kwargs )
327
+ else :
328
+ raise ValueError ('Unknown method %s.' % method )
304
329
305
330
306
331
def sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m , reg_type = "entropy" ,
@@ -443,8 +468,6 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
443
468
v = nx .ones (dim_b , type_as = M )
444
469
else :
445
470
u , v = nx .exp (warmstart [0 ]), nx .exp (warmstart [1 ])
446
- if not n_hists :
447
- u , v = u .reshape (- 1 ), v .reshape (- 1 )
448
471
449
472
if reg_type == "kl" :
450
473
K = nx .exp (- M / reg ) * a .reshape (- 1 )[:, None ] * b .reshape (- 1 )[None , :]
@@ -654,8 +677,6 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
654
677
v = nx .ones (dim_b , type_as = M )
655
678
else :
656
679
u , v = nx .exp (warmstart [0 ]), nx .exp (warmstart [1 ])
657
- if not n_hists :
658
- u , v = u .reshape (- 1 ), v .reshape (- 1 )
659
680
660
681
if reg_type == "kl" :
661
682
log_ab = nx .log (a + 1e-16 ).reshape (- 1 )[:, None ] + nx .log (b + 1e-16 ).reshape (- 1 )[None , :]
0 commit comments