Skip to content

Commit 402729f

Browse files
author
William de Vazelhes
committed
FIX the threshold by taking the opposite (to be adapted to the decision function)
1 parent dc9e21d commit 402729f

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

metric_learn/base_metric.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def predict(self, pairs):
332332
The predicted learned metric value between samples in every pair.
333333
"""
334334
check_is_fitted(self, ['threshold_', 'transformer_'])
335-
return - 2 * (self.decision_function(pairs) > self.threshold_) + 1
335+
return 2 * (self.decision_function(pairs) > self.threshold_) - 1
336336

337337
def decision_function(self, pairs):
338338
"""Returns the decision function used to classify the pairs.
@@ -387,13 +387,13 @@ def score(self, pairs, y):
387387
return roc_auc_score(y, self.decision_function(pairs))
388388

389389
def set_default_threshold(self, pairs, y):
390-
"""Returns a threshold that is the mean between the similar metrics
391-
mean, and the dissimilar metrics mean"""
392-
similar_threshold = np.mean(self.decision_function(
390+
"""Returns a threshold that is the opposite of the mean between the similar
391+
metrics mean and the dissimilar metrics mean"""
392+
similar_threshold = np.mean(self.score_pairs(
393393
pairs[(y == 1).ravel()]))
394-
dissimilar_threshold = np.mean(self.decision_function(
394+
dissimilar_threshold = np.mean(self.score_pairs(
395395
pairs[(y == -1).ravel()]))
396-
self.threshold_ = np.mean([similar_threshold, dissimilar_threshold])
396+
self.threshold_ = - np.mean([similar_threshold, dissimilar_threshold])
397397

398398

399399
class _QuadrupletsClassifierMixin(BaseMetricLearner):

metric_learn/itml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def fit(self, pairs, y, bounds=None):
187187
Returns the instance.
188188
"""
189189
self._fit(pairs, y, bounds=bounds)
190-
self.threshold_ = np.mean(self.bounds_)
190+
self.threshold_ = - np.mean(self.bounds_)
191191
return self
192192

193193

0 commit comments

Comments
 (0)