19
19
from .utils import list_to_array , get_parameter_pair
20
20
21
21
22
- def sinkhorn_unbalanced (a , b , M , reg , reg_m , reg_type = "entropy" , warmstart = None ,
23
- method = 'sinkhorn' , numItermax = 1000 ,
22
+ def sinkhorn_unbalanced (a , b , M , reg , reg_m , method = 'sinkhorn' ,
23
+ reg_type = "entropy" , warmstart = None , numItermax = 1000 ,
24
24
stopThr = 1e-6 , verbose = False , log = False , ** kwargs ):
25
25
r"""
26
26
Solve the unbalanced entropic regularization optimal transport problem
@@ -67,6 +67,9 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None,
67
67
For semi-relaxed case, use either
68
68
`reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
69
69
If reg_m is an array, it must have the same backend as input arrays (a, b, M).
70
+ method : str
71
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
72
+ 'sinkhorn_reg_scaling', see those function for specific parameters
70
73
reg_type : string, optional
71
74
Regularizer term. Can take two values:
72
75
'entropy' (negative entropy)
@@ -75,10 +78,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None,
75
78
:math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
76
79
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
77
80
Initialization of dual potentials. If provided, the dual potentials should be given
78
- (that is the logarithm of the u,v sinkhorn scaling vectors).s
79
- method : str
80
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
81
- 'sinkhorn_reg_scaling', see those function for specific parameters
81
+ (that is the logarithm of the u,v sinkhorn scaling vectors).
82
82
numItermax : int, optional
83
83
Max number of iterations
84
84
stopThr : float, optional
@@ -165,8 +165,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None,
165
165
raise ValueError ("Unknown method '%s'." % method )
166
166
167
167
168
- def sinkhorn_unbalanced2 (a , b , M , reg , reg_m , reg_type = "entropy" , warmstart = None ,
169
- method = 'sinkhorn' , numItermax = 1000 ,
168
+ def sinkhorn_unbalanced2 (a , b , M , reg , reg_m , method = 'sinkhorn' ,
169
+ reg_type = "entropy" , warmstart = None , numItermax = 1000 ,
170
170
stopThr = 1e-6 , verbose = False , log = False , ** kwargs ):
171
171
r"""
172
172
Solve the entropic regularization unbalanced optimal transport problem and
@@ -212,6 +212,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None
212
212
For semi-relaxed case, use either
213
213
`reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
214
214
If reg_m is an array, it must have the same backend as input arrays (a, b, M).
215
+ method : str
216
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
217
+ 'sinkhorn_reg_scaling', see those function for specific parameterss
215
218
reg_type : string, optional
216
219
Regularizer term. Can take two values:
217
220
'entropy' (negative entropy)
@@ -221,9 +224,6 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, reg_type="entropy", warmstart=None
221
224
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
222
225
Initialization of dual potentials. If provided, the dual potentials should be given
223
226
(that is the logarithm of the u,v sinkhorn scaling vectors).
224
- method : str
225
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
226
- 'sinkhorn_reg_scaling', see those function for specific parameters
227
227
numItermax : int, optional
228
228
Max number of iterations
229
229
stopThr : float, optional
@@ -435,12 +435,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
435
435
# distances
436
436
if warmstart is None :
437
437
if n_hists :
438
- u = nx .ones ((dim_a , 1 ), type_as = M ) / dim_a
439
- v = nx .ones ((dim_b , n_hists ), type_as = M ) / dim_b
438
+ u = nx .ones ((dim_a , 1 ), type_as = M )
439
+ v = nx .ones ((dim_b , n_hists ), type_as = M )
440
440
a = a .reshape (dim_a , 1 )
441
441
else :
442
- u = nx .ones (dim_a , type_as = M ) / dim_a
443
- v = nx .ones (dim_b , type_as = M ) / dim_b
442
+ u = nx .ones (dim_a , type_as = M )
443
+ v = nx .ones (dim_b , type_as = M )
444
444
else :
445
445
u , v = nx .exp (warmstart [0 ]), nx .exp (warmstart [1 ])
446
446
@@ -644,12 +644,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
644
644
# distances
645
645
if warmstart is None :
646
646
if n_hists :
647
- u = nx .ones ((dim_a , n_hists ), type_as = M ) / dim_a
648
- v = nx .ones ((dim_b , n_hists ), type_as = M ) / dim_b
647
+ u = nx .ones ((dim_a , n_hists ), type_as = M )
648
+ v = nx .ones ((dim_b , n_hists ), type_as = M )
649
649
a = a .reshape (dim_a , 1 )
650
650
else :
651
- u = nx .ones (dim_a , type_as = M ) / dim_a
652
- v = nx .ones (dim_b , type_as = M ) / dim_b
651
+ u = nx .ones (dim_a , type_as = M )
652
+ v = nx .ones (dim_b , type_as = M )
653
653
else :
654
654
u , v = nx .exp (warmstart [0 ]), nx .exp (warmstart [1 ])
655
655
0 commit comments