Skip to content

Commit 255789c

Browse files
committed
implement batches
1 parent c356874 commit 255789c

File tree

1 file changed

+34
-8
lines changed

1 file changed

+34
-8
lines changed

ot/da.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,12 +1208,25 @@ class label
12081208
# check the necessary inputs parameters are here
12091209
g = self.log_['log_v']
12101210

1211-
M = dist(Xs, self.xt_, metric=self.metric)
1212-
M = cost_normalization(M, self.norm, value=self.norm_cost_)
1211+
indices = nx.arange(Xs.shape[0])
1212+
batch_ind = [
1213+
indices[i:i + batch_size]
1214+
for i in range(0, len(indices), batch_size)]
12131215

1214-
K = nx.exp(-M / self.reg_e + g[None, :])
1216+
transp_Xs = []
1217+
for bi in batch_ind:
1218+
# get the nearest neighbor in the source domain
1219+
M = dist(Xs[bi], self.xt_, metric=self.metric)
12151220

1216-
transp_Xs = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None]
1221+
M = cost_normalization(M, self.norm, value=self.norm_cost_)
1222+
1223+
K = nx.exp(-M / self.reg_e + g[None, :])
1224+
1225+
transp_Xs_ = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None]
1226+
1227+
transp_Xs.append(transp_Xs_)
1228+
1229+
transp_Xs = nx.concatenate(transp_Xs, axis=0)
12171230

12181231
return transp_Xs
12191232

@@ -1252,12 +1265,25 @@ class label
12521265

12531266
f = self.log_['log_u']
12541267

1255-
M = dist(Xt, self.xs_, metric=self.metric)
1256-
M = cost_normalization(M, self.norm, value=self.norm_cost_)
1268+
indices = nx.arange(Xt.shape[0])
1269+
batch_ind = [
1270+
indices[i:i + batch_size]
1271+
for i in range(0, len(indices), batch_size
1272+
)]
1273+
1274+
transp_Xt = []
1275+
for bi in batch_ind:
1276+
1277+
M = dist(Xt[bi], self.xs_, metric=self.metric)
1278+
M = cost_normalization(M, self.norm, value=self.norm_cost_)
1279+
1280+
K = nx.exp(-M / self.reg_e + f[None, :])
1281+
1282+
transp_Xt_ = nx.dot(K, self.xs_) / nx.sum(K, axis=1)[:, None]
12571283

1258-
K = nx.exp(-M / self.reg_e + f[None, :])
1284+
transp_Xt.append(transp_Xt_)
12591285

1260-
transp_Xt = nx.dot(K, self.xs_) / nx.sum(K, axis=1)[:, None]
1286+
transp_Xt = nx.concatenate(transp_Xt, axis=0)
12611287

12621288
return transp_Xt
12631289

0 commit comments

Comments
 (0)