Skip to content

Commit c356874

Browse files
committed
have corect normalization for entropic mapping
1 parent a7ce443 commit c356874

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

ot/da.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ class label
493493

494494
# pairwise distance
495495
self.cost_ = dist(Xs, Xt, metric=self.metric)
496-
self.cost_ = cost_normalization(self.cost_, self.norm)
496+
self.cost_, self.norm_cost_ = cost_normalization(self.cost_, self.norm, return_value=True)
497497

498498
if (ys is not None) and (yt is not None):
499499

@@ -1209,7 +1209,7 @@ class label
12091209
g = self.log_['log_v']
12101210

12111211
M = dist(Xs, self.xt_, metric=self.metric)
1212-
M = cost_normalization(M, self.norm)
1212+
M = cost_normalization(M, self.norm, value=self.norm_cost_)
12131213

12141214
K = nx.exp(-M / self.reg_e + g[None, :])
12151215

@@ -1253,7 +1253,7 @@ class label
12531253
f = self.log_['log_u']
12541254

12551255
M = dist(Xt, self.xs_, metric=self.metric)
1256-
M = cost_normalization(M, self.norm)
1256+
M = cost_normalization(M, self.norm, value=self.norm_cost_)
12571257

12581258
K = nx.exp(-M / self.reg_e + f[None, :])
12591259

ot/utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def dist0(n, method='lin_square'):
360360
return res
361361

362362

363-
def cost_normalization(C, norm=None):
363+
def cost_normalization(C, norm=None, return_value=False, value=None):
364364
r""" Apply normalization to the loss matrix
365365
366366
Parameters
@@ -382,9 +382,13 @@ def cost_normalization(C, norm=None):
382382
if norm is None:
383383
pass
384384
elif norm == "median":
385-
C /= float(nx.median(C))
385+
if value is None:
386+
value = nx.median(C)
387+
C /= value
386388
elif norm == "max":
387-
C /= float(nx.max(C))
389+
if value is None:
390+
value = nx.max(C)
391+
C /= float(value)
388392
elif norm == "log":
389393
C = nx.log(1 + C)
390394
elif norm == "loglog":
@@ -393,7 +397,10 @@ def cost_normalization(C, norm=None):
393397
raise ValueError('Norm %s is not a valid option.\n'
394398
'Valid options are:\n'
395399
'median, max, log, loglog' % norm)
396-
return C
400+
if return_value:
401+
return C, value
402+
else:
403+
return C
397404

398405

399406
def dots(*args):

0 commit comments

Comments
 (0)