 from .utils import unif, dist, get_lowrank_lazytensor
 from .backend import get_backend
 from .bregman import sinkhorn
-from sklearn.cluster import KMeans
 
+# test if sklearn is installed for linux-minimal-deps
+try:
+    import sklearn.cluster
+    sklearn_import = True
+except ImportError:
+    sklearn_import = False
 
-def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init=None, random_state=None, nx=None):
+
+def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init, random_state, nx=None):
     """
     Implementation of different initialization strategies for the low rank sinkhorn solver (Q, R, g).
     This function is specific to lowrank_sinkhorn.
@@ -33,11 +39,11 @@ def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init=None, random_state=No
         Nonnegative rank of the OT plan.
     init : str
         Initialization strategy for Q, R and g. 'random', 'trivial' or 'kmeans'
-    reg_init : float, optional. Default is None. (>0)
-        Regularization term for a 'kmeans' init. If None, 1 is considered.
-    random_state : default None
+    reg_init : float, optional.
+        Regularization term for a 'kmeans' init.
+    random_state : int, optional.
         Random state for a "random" or 'kmeans' init strategy
-    nx : default None
+    nx : optional, Default is None
         POT backend
 
 
@@ -61,12 +67,6 @@ def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init=None, random_state=No
     if nx is None:
         nx = get_backend(X_s, X_t, a, b)
 
-    if reg_init is None:
-        reg_init = 0.1
-
-    if random_state is None:
-        random_state = 49
-
     ns = X_s.shape[0]
     nt = X_t.shape[0]
     r = rank
@@ -86,7 +86,7 @@ def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init=None, random_state=No
         R = nx.abs(nx.randn(nt, rank, type_as=X_s)) + 1
         R = (R.T * (b / nx.sum(R, axis=1))).T
 
-    if init == "trivial":
+    if init == "deterministic":
         # Init g
         g = nx.ones(rank) / rank
 
@@ -114,24 +114,28 @@ def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init=None, random_state=No
         R = R1 + R2
 
     if init == "kmeans":
-        # Init g
-        g = nx.ones(rank, type_as=X_s) / rank
-
-        # Init Q
-        kmeans_Xs = KMeans(n_clusters=rank, random_state=random_state, n_init="auto")
-        kmeans_Xs.fit(X_s)
-        Z_Xs = nx.from_numpy(kmeans_Xs.cluster_centers_)
-        C_Xs = dist(X_s, Z_Xs)  # shape (ns, rank)
-        C_Xs = C_Xs / nx.max(C_Xs)
-        Q = sinkhorn(a, g, C_Xs, reg=reg_init, numItermax=10000, stopThr=1e-3)
+        if sklearn_import:
+            # Init g
+            g = nx.ones(rank, type_as=X_s) / rank
+
+            # Init Q
+            kmeans_Xs = sklearn.cluster.KMeans(n_clusters=rank, random_state=random_state, n_init="auto")
+            kmeans_Xs.fit(X_s)
+            Z_Xs = nx.from_numpy(kmeans_Xs.cluster_centers_)
+            C_Xs = dist(X_s, Z_Xs)  # shape (ns, rank)
+            C_Xs = C_Xs / nx.max(C_Xs)
+            Q = sinkhorn(a, g, C_Xs, reg=reg_init, numItermax=10000, stopThr=1e-3)
+
+            # Init R
+            kmeans_Xt = sklearn.cluster.KMeans(n_clusters=rank, random_state=random_state, n_init="auto")
+            kmeans_Xt.fit(X_t)
+            Z_Xt = nx.from_numpy(kmeans_Xt.cluster_centers_)
+            C_Xt = dist(X_t, Z_Xt)  # shape (nt, rank)
+            C_Xt = C_Xt / nx.max(C_Xt)
+            R = sinkhorn(b, g, C_Xt, reg=reg_init, numItermax=10000, stopThr=1e-3)
 
-        # Init R
-        kmeans_Xt = KMeans(n_clusters=rank, random_state=random_state, n_init="auto")
-        kmeans_Xt.fit(X_t)
-        Z_Xt = nx.from_numpy(kmeans_Xt.cluster_centers_)
-        C_Xt = dist(X_t, Z_Xt)  # shape (nt, rank)
-        C_Xt = C_Xt / nx.max(C_Xt)
-        R = sinkhorn(b, g, C_Xt, reg=reg_init, numItermax=10000, stopThr=1e-3)
+        else:
+            raise ImportError("Scikit-learn should be installed to use the 'kmeans' init.")
 
     return Q, R, g
 
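The hunk above routes the k-means initialization through the guarded sklearn.cluster import: the cluster centers of X_s and X_t serve as anchor points, the normalized distances to them become small cost matrices, and a plain Sinkhorn solve against the uniform weights g yields the factors Q and R. Below is a minimal usage sketch of the updated helper, assuming it remains importable as ot.lowrank._init_lr_sinkhorn (a private helper, so subject to change) and that scikit-learn is installed; with sklearn missing, this call now raises the ImportError shown above.

# Hedged sketch: exercising the 'kmeans' init of the private helper.
import numpy as np
import ot
from ot.lowrank import _init_lr_sinkhorn  # assumed module path

rng = np.random.RandomState(0)
X_s = rng.randn(50, 2)
X_t = rng.randn(40, 2)
a, b = ot.unif(50), ot.unif(40)

# reg_init and random_state no longer default to None, so pass them explicitly.
Q, R, g = _init_lr_sinkhorn(X_s, X_t, a, b, rank=4, init="kmeans",
                            reg_init=1e-1, random_state=49)

# Q and R couple the samples to the rank-4 weights g, so their row sums
# should approximately recover a and b (up to the sinkhorn stopThr=1e-3).
print(Q.shape, R.shape, g.shape)  # (50, 4) (40, 4) (4,)
print(np.allclose(Q.sum(axis=1), a, atol=1e-2),
      np.allclose(R.sum(axis=1), b, atol=1e-2))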
@@ -306,7 +310,7 @@ def _LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=N
 
 
 def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, rescale_cost=True,
-                     init="random", reg_init=None, seed_init=None, gamma_init="rescale",
+                     init="random", reg_init=1e-1, seed_init=49, gamma_init="rescale",
                      numItermax=2000, stopThr=1e-7, warn=True, log=False):
     r"""
     Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints
@@ -347,10 +351,10 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, re
     rescale_cost : bool, optional. Default is False
         Rescale the low rank factorization of the sqeuclidean cost matrix
     init : str, optional. Default is 'random'.
-        Initialization strategy for the low rank couplings. 'random', 'trivial' or 'kmeans'
-    reg_init : float, optional. Default is None. (>0)
+        Initialization strategy for the low rank couplings. 'random', 'deterministic' or 'kmeans'
+    reg_init : float, optional. Default is 1e-1. (>0)
         Regularization term for a 'kmeans' init. If None, 1 is considered.
-    seed_init : int, optional. Default is None. (>0)
+    seed_init : int, optional. Default is 49. (>0)
         Random state for a 'random' or 'kmeans' init strategy.
     gamma_init : str, optional. Default is "rescale".
         Initialization strategy for gamma. 'rescale', or 'theory'
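Putting the new defaults together, here is a hedged sketch of the public solver. The call signature and the defaults (reg_init=1e-1, seed_init=49) are taken from the hunks above, while the return convention (the factors Q, R, g when log=False) and the reconstruction of the dense plan as Q @ diag(1/g) @ R.T are assumptions based on the low-rank Sinkhorn factorization rather than something shown in this diff.

# Hedged sketch: public low-rank solver with the 'kmeans' init.
import numpy as np
from ot.lowrank import lowrank_sinkhorn  # assumed module path

rng = np.random.RandomState(0)
X_s = rng.randn(100, 2)
X_t = rng.randn(80, 2)

# 'kmeans' now requires scikit-learn; the former 'trivial' init is now 'deterministic'.
Q, R, g = lowrank_sinkhorn(X_s, X_t, reg=0.1, rank=5, init="kmeans")

P = Q @ np.diag(1 / g) @ R.T  # assumed dense low-rank plan, shape (100, 80)
print(P.shape, np.round(P.sum(), 3))  # total mass should be close to 1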