@@ -332,7 +332,7 @@ def predict(self, pairs):
332
332
The predicted learned metric value between samples in every pair.
333
333
"""
334
334
check_is_fitted (self , ['threshold_' , 'transformer_' ])
335
- return - 2 * (self .decision_function (pairs ) > self .threshold_ ) + 1
335
+ return 2 * (self .decision_function (pairs ) > self .threshold_ ) - 1
336
336
337
337
def decision_function (self , pairs ):
338
338
"""Returns the decision function used to classify the pairs.
@@ -387,13 +387,13 @@ def score(self, pairs, y):
387
387
return roc_auc_score (y , self .decision_function (pairs ))
388
388
389
389
def set_default_threshold (self , pairs , y ):
390
- """Returns a threshold that is the mean between the similar metrics
391
- mean, and the dissimilar metrics mean"""
392
- similar_threshold = np .mean (self .decision_function (
390
+ """Returns a threshold that is the opposite of the mean between the similar
391
+ metrics mean and the dissimilar metrics mean"""
392
+ similar_threshold = np .mean (self .score_pairs (
393
393
pairs [(y == 1 ).ravel ()]))
394
- dissimilar_threshold = np .mean (self .decision_function (
394
+ dissimilar_threshold = np .mean (self .score_pairs (
395
395
pairs [(y == - 1 ).ravel ()]))
396
- self .threshold_ = np .mean ([similar_threshold , dissimilar_threshold ])
396
+ self .threshold_ = - np .mean ([similar_threshold , dissimilar_threshold ])
397
397
398
398
399
399
class _QuadrupletsClassifierMixin (BaseMetricLearner ):
0 commit comments