@@ -1208,12 +1208,25 @@ class label
1208
1208
# check the necessary inputs parameters are here
1209
1209
g = self .log_ ['log_v' ]
1210
1210
1211
- M = dist (Xs , self .xt_ , metric = self .metric )
1212
- M = cost_normalization (M , self .norm , value = self .norm_cost_ )
1211
+ indices = nx .arange (Xs .shape [0 ])
1212
+ batch_ind = [
1213
+ indices [i :i + batch_size ]
1214
+ for i in range (0 , len (indices ), batch_size )]
1213
1215
1214
- K = nx .exp (- M / self .reg_e + g [None , :])
1216
+ transp_Xs = []
1217
+ for bi in batch_ind :
1218
+ # get the nearest neighbor in the source domain
1219
+ M = dist (Xs [bi ], self .xt_ , metric = self .metric )
1215
1220
1216
- transp_Xs = nx .dot (K , self .xt_ ) / nx .sum (K , axis = 1 )[:, None ]
1221
+ M = cost_normalization (M , self .norm , value = self .norm_cost_ )
1222
+
1223
+ K = nx .exp (- M / self .reg_e + g [None , :])
1224
+
1225
+ transp_Xs_ = nx .dot (K , self .xt_ ) / nx .sum (K , axis = 1 )[:, None ]
1226
+
1227
+ transp_Xs .append (transp_Xs_ )
1228
+
1229
+ transp_Xs = nx .concatenate (transp_Xs , axis = 0 )
1217
1230
1218
1231
return transp_Xs
1219
1232
@@ -1252,12 +1265,25 @@ class label
1252
1265
1253
1266
f = self .log_ ['log_u' ]
1254
1267
1255
- M = dist (Xt , self .xs_ , metric = self .metric )
1256
- M = cost_normalization (M , self .norm , value = self .norm_cost_ )
1268
+ indices = nx .arange (Xt .shape [0 ])
1269
+ batch_ind = [
1270
+ indices [i :i + batch_size ]
1271
+ for i in range (0 , len (indices ), batch_size
1272
+ )]
1273
+
1274
+ transp_Xt = []
1275
+ for bi in batch_ind :
1276
+
1277
+ M = dist (Xt [bi ], self .xs_ , metric = self .metric )
1278
+ M = cost_normalization (M , self .norm , value = self .norm_cost_ )
1279
+
1280
+ K = nx .exp (- M / self .reg_e + f [None , :])
1281
+
1282
+ transp_Xt_ = nx .dot (K , self .xs_ ) / nx .sum (K , axis = 1 )[:, None ]
1257
1283
1258
- K = nx . exp ( - M / self . reg_e + f [ None , :] )
1284
+ transp_Xt . append ( transp_Xt_ )
1259
1285
1260
- transp_Xt = nx .dot ( K , self . xs_ ) / nx . sum ( K , axis = 1 )[:, None ]
1286
+ transp_Xt = nx .concatenate ( transp_Xt , axis = 0 )
1261
1287
1262
1288
return transp_Xt
1263
1289
0 commit comments