@@ -493,7 +493,7 @@ class label
         # pairwise distance
         self.cost_ = dist(Xs, Xt, metric=self.metric)
-        self.cost_ = cost_normalization(self.cost_, self.norm)
+        self.cost_, self.norm_cost_ = cost_normalization(self.cost_, self.norm, return_value=True)

         if (ys is not None) and (yt is not None):
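The value returned by cost_normalization here (stored as self.norm_cost_) is what lets the new out of sample transforms below rescale fresh cost matrices consistently with the one used at fit time. A minimal numpy sketch of that round trip, with a hypothetical normalize_cost helper standing in for ot.utils.cost_normalization:

    import numpy as np

    def normalize_cost(C, norm=None, value=None):
        # hypothetical stand-in for ot.utils.cost_normalization: with
        # norm='median', rescale by a stored median so costs computed at
        # fit time and at transform time share one scale
        if norm == 'median':
            if value is None:
                value = np.median(C)
            return C / value, value
        return C, value

    rng = np.random.RandomState(0)
    # fit time: compute the scaling and keep it (return_value=True above)
    C_fit, med = normalize_cost(np.abs(rng.randn(5, 5)), norm='median')
    # transform time: reuse the stored value on an out-of-sample cost matrix
    C_new, _ = normalize_cost(np.abs(rng.randn(3, 5)), norm='median', value=med)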
@@ -1055,13 +1055,18 @@ class SinkhornTransport(BaseTransport):
         The ground metric for the Wasserstein problem
     norm : string, optional (default=None)
         If given, normalize the ground metric to avoid numerical errors that
-        can occur with large metric values.
+        can occur with large metric values. Accepted values are 'median',
+        'max', 'log' and 'loglog'.

     distribution_estimation : callable, optional (defaults to the uniform)
         The kind of distribution estimation to employ
-    out_of_sample_map : string, optional (default="ferradans")
+    out_of_sample_map : string, optional (default="continuous")
         The kind of out of sample mapping to apply to transport samples
-        from a domain into another one. Currently the only possible option is
-        "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhorntransport>`.
+        from a domain into another one. Possible options are "ferradans",
+        which uses the nearest neighbor method proposed in
+        :ref:`[6] <references-sinkhorntransport>`, and "continuous", which uses
+        the out of sample method from :ref:`[66] <references-sinkhorntransport>`
+        and :ref:`[19] <references-sinkhorntransport>`.

     limit_max: float, optional (default=np.infty)
         Controls the semi supervised mode. Transport between labeled source
         and target samples of different classes will exhibit a cost defined
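For reference, the four accepted norm values already exist in ot.utils.cost_normalization; a quick sketch of their effect on a raw cost matrix, assuming only the pre-existing (C, norm) call signature:

    import numpy as np
    import ot
    from ot.utils import cost_normalization

    rng = np.random.RandomState(0)
    M = ot.dist(rng.randn(10, 2), rng.randn(8, 2))  # squared Euclidean costs
    for norm in ['median', 'max', 'log', 'loglog']:
        print(norm, cost_normalization(M.copy(), norm).max())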
@@ -1089,13 +1094,26 @@ class SinkhornTransport(BaseTransport):
     .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
         Regularized discrete optimal transport. SIAM Journal on Imaging
         Sciences, 7(3), 1853-1882.
+
+    .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N.,
+        Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping
+        Estimation. International Conference on Learning Representations (2018)
+
+    .. [66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. "Entropic
+        estimation of optimal transport maps." arXiv preprint
+        arXiv:2109.12004 (2021).
+
     """

-    def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000,
+    def __init__(self, reg_e=1., method="sinkhorn_log", max_iter=1000,
                  tol=10e-9, verbose=False, log=False,
                  metric="sqeuclidean", norm=None,
                  distribution_estimation=distribution_estimation_uniform,
-                 out_of_sample_map='ferradans', limit_max=np.infty):
+                 out_of_sample_map='continuous', limit_max=np.infty):
+
+        if out_of_sample_map not in ['ferradans', 'continuous']:
+            raise ValueError('Unknown out_of_sample_map method')
+
         self.reg_e = reg_e
         self.method = method
         self.max_iter = max_iter
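A quick usage sketch of the new defaults and the added validation, assuming the class as patched in this diff:

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    Xs = rng.randn(50, 2)
    Xt = rng.randn(40, 2) + 2

    # new defaults: log-domain Sinkhorn and the 'continuous' out of sample map
    mapper = ot.da.SinkhornTransport(reg_e=1.)
    mapper.fit(Xs=Xs, Xt=Xt)

    # unknown mappings are now rejected at construction time
    try:
        ot.da.SinkhornTransport(out_of_sample_map='barycentric')
    except ValueError as err:
        print(err)  # Unknown out_of_sample_map method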
@@ -1135,6 +1153,12 @@ class label
         super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)

+        if self.out_of_sample_map == 'continuous':
+            self.log = True
+            if self.method != 'sinkhorn_log':
+                self.method = 'sinkhorn_log'
+                warnings.warn("The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='continuous'")
+
         # coupling estimation
         returned_ = sinkhorn(
             a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
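A sketch of the fallback this hunk introduces, again assuming the patched class: the continuous map needs the dual potentials, so fit() forces log=True and upgrades any other solver to 'sinkhorn_log' with a warning rather than an error:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Xs = rng.randn(30, 2)
    Xt = rng.randn(30, 2) + 1

    # ask for the plain solver but the continuous map: fit() warns and switches
    mapper = ot.da.SinkhornTransport(method='sinkhorn', out_of_sample_map='continuous')
    mapper.fit(Xs=Xs, Xt=Xt)
    print(mapper.method)  # 'sinkhorn_log'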
@@ -1150,6 +1174,120 @@ class label
         return self

+    def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+        r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
+
+        Parameters
+        ----------
+        Xs : array-like, shape (n_source_samples, n_features)
+            The source input samples.
+        ys : array-like, shape (n_source_samples,)
+            The class labels for source samples
+        Xt : array-like, shape (n_target_samples, n_features)
+            The target input samples.
+        yt : array-like, shape (n_target_samples,)
+            The class labels for target. If some target samples are unlabelled, fill the
+            :math:`\mathbf{y_t}`'s elements with -1.
+
+            Warning: Note that, due to this convention -1 cannot be used as a
+            class label
+        batch_size : int, optional (default=128)
+            The batch size for out of sample transform
+
+        Returns
+        -------
+        transp_Xs : array-like, shape (n_source_samples, n_features)
+            The transported source samples.
+        """
+        nx = self.nx
+
+        if self.out_of_sample_map == 'ferradans':
+            return super(SinkhornTransport, self).transform(Xs, ys, Xt, yt, batch_size)
+
+        else:  # self.out_of_sample_map == 'continuous'
+
+            # retrieve the target dual potential estimated by sinkhorn_log
+            g = self.log_['log_v']
+
+            indices = nx.arange(Xs.shape[0])
+            batch_ind = [
+                indices[i:i + batch_size]
+                for i in range(0, len(indices), batch_size)]
+
+            transp_Xs = []
+            for bi in batch_ind:
+                # cost from the batch to the target samples, normalized with
+                # the value stored at fit time
+                M = dist(Xs[bi], self.xt_, metric=self.metric)
+                M = cost_normalization(M, self.norm, value=self.norm_cost_)
+
+                # entropic barycentric mapping of the batch
+                K = nx.exp(-M / self.reg_e + g[None, :])
+                transp_Xs_ = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None]
+
+                transp_Xs.append(transp_Xs_)
+
+            transp_Xs = nx.concatenate(transp_Xs, axis=0)
+
+            return transp_Xs
+
+    def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+        r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
+
+        Parameters
+        ----------
+        Xs : array-like, shape (n_source_samples, n_features)
+            The source input samples.
+        ys : array-like, shape (n_source_samples,)
+            The class labels for source samples
+        Xt : array-like, shape (n_target_samples, n_features)
+            The target input samples.
+        yt : array-like, shape (n_target_samples,)
+            The class labels for target. If some target samples are unlabelled, fill the
+            :math:`\mathbf{y_t}`'s elements with -1.
+
+            Warning: Note that, due to this convention -1 cannot be used as a
+            class label
+        batch_size : int, optional (default=128)
+            The batch size for out of sample inverse transform
+
+        Returns
+        -------
+        transp_Xt : array-like, shape (n_target_samples, n_features)
+            The transported target samples.
+        """
+        nx = self.nx
+
+        if self.out_of_sample_map == 'ferradans':
+            return super(SinkhornTransport, self).inverse_transform(Xs, ys, Xt, yt, batch_size)
+
+        else:  # self.out_of_sample_map == 'continuous'
+
+            # retrieve the source dual potential estimated by sinkhorn_log
+            f = self.log_['log_u']
+
+            indices = nx.arange(Xt.shape[0])
+            batch_ind = [
+                indices[i:i + batch_size]
+                for i in range(0, len(indices), batch_size)]
+
+            transp_Xt = []
+            for bi in batch_ind:
+                # cost from the batch to the source samples, normalized with
+                # the value stored at fit time
+                M = dist(Xt[bi], self.xs_, metric=self.metric)
+                M = cost_normalization(M, self.norm, value=self.norm_cost_)
+
+                # entropic barycentric mapping of the batch
+                K = nx.exp(-M / self.reg_e + f[None, :])
+                transp_Xt_ = nx.dot(K, self.xs_) / nx.sum(K, axis=1)[:, None]
+
+                transp_Xt.append(transp_Xt_)
+
+            transp_Xt = nx.concatenate(transp_Xt, axis=0)
+
+            return transp_Xt
+
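Both branches implement the entropic barycentric map of [66] and [19]: a new sample x is sent to T(x) = sum_j exp(-M(x, y_j) / reg_e + g_j) y_j / sum_j exp(-M(x, y_j) / reg_e + g_j), where g is the target potential from the Sinkhorn log (the inverse direction swaps in f and the source samples). A numpy sketch checking transform() against that closed form, assuming norm=None so cost normalization is a no-op:

    import numpy as np
    import ot

    rng = np.random.RandomState(1)
    Xs = rng.randn(40, 2)
    Xt = rng.randn(30, 2) + 1

    mapper = ot.da.SinkhornTransport(reg_e=1., out_of_sample_map='continuous')
    mapper.fit(Xs=Xs, Xt=Xt)

    Xs_new = rng.randn(5, 2)             # samples unseen at fit time
    T_new = mapper.transform(Xs=Xs_new)  # out of sample mapping

    # the same mapping written out by hand
    M = ot.dist(Xs_new, Xt, metric='sqeuclidean')
    K = np.exp(-M / mapper.reg_e + mapper.log_['log_v'][None, :])
    T_manual = K.dot(Xt) / K.sum(axis=1, keepdims=True)
    assert np.allclose(T_new, T_manual)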
 class EMDTransport(BaseTransport):