File tree Expand file tree Collapse file tree 2 files changed +14
-7
lines changed Expand file tree Collapse file tree 2 files changed +14
-7
lines changed Original file line number Diff line number Diff line change @@ -493,7 +493,7 @@ class label
493
493
494
494
# pairwise distance
495
495
self .cost_ = dist (Xs , Xt , metric = self .metric )
496
- self .cost_ = cost_normalization (self .cost_ , self .norm )
496
+ self .cost_ , self . norm_cost_ = cost_normalization (self .cost_ , self .norm , return_value = True )
497
497
498
498
if (ys is not None ) and (yt is not None ):
499
499
@@ -1209,7 +1209,7 @@ class label
1209
1209
g = self .log_ ['log_v' ]
1210
1210
1211
1211
M = dist (Xs , self .xt_ , metric = self .metric )
1212
- M = cost_normalization (M , self .norm )
1212
+ M = cost_normalization (M , self .norm , value = self . norm_cost_ )
1213
1213
1214
1214
K = nx .exp (- M / self .reg_e + g [None , :])
1215
1215
@@ -1253,7 +1253,7 @@ class label
1253
1253
f = self .log_ ['log_u' ]
1254
1254
1255
1255
M = dist (Xt , self .xs_ , metric = self .metric )
1256
- M = cost_normalization (M , self .norm )
1256
+ M = cost_normalization (M , self .norm , value = self . norm_cost_ )
1257
1257
1258
1258
K = nx .exp (- M / self .reg_e + f [None , :])
1259
1259
Original file line number Diff line number Diff line change @@ -360,7 +360,7 @@ def dist0(n, method='lin_square'):
360
360
return res
361
361
362
362
363
- def cost_normalization (C , norm = None ):
363
+ def cost_normalization (C , norm = None , return_value = False , value = None ):
364
364
r""" Apply normalization to the loss matrix
365
365
366
366
Parameters
@@ -382,9 +382,13 @@ def cost_normalization(C, norm=None):
382
382
if norm is None :
383
383
pass
384
384
elif norm == "median" :
385
- C /= float (nx .median (C ))
385
+ if value is None :
386
+ value = nx .median (C )
387
+ C /= value
386
388
elif norm == "max" :
387
- C /= float (nx .max (C ))
389
+ if value is None :
390
+ value = nx .max (C )
391
+ C /= float (value )
388
392
elif norm == "log" :
389
393
C = nx .log (1 + C )
390
394
elif norm == "loglog" :
@@ -393,7 +397,10 @@ def cost_normalization(C, norm=None):
393
397
raise ValueError ('Norm %s is not a valid option.\n '
394
398
'Valid options are:\n '
395
399
'median, max, log, loglog' % norm )
396
- return C
400
+ if return_value :
401
+ return C , value
402
+ else :
403
+ return C
397
404
398
405
399
406
def dots (* args ):
You can’t perform that action at this time.
0 commit comments