Skip to content

Commit 8c3cb3e

Browse files
wdevazelhesperimosocordiae
authored andcommitted
[MRG] Fix quadruplets scoring (#220)
* FIX: fix lsml scoring * Address #220 (review)
1 parent 8518517 commit 8c3cb3e

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

metric_learn/base_metric.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -641,4 +641,9 @@ def score(self, quadruplets):
641641
score : float
642642
The quadruplets score.
643643
"""
644-
return - np.mean(self.predict(quadruplets))
644+
# Since the prediction is a vector of values in {-1, +1}, we need to
645+
# rescale them to {0, 1} to compute the accuracy using the mean (because
646+
# then 1 means a correctly classified result (pairs are in the right
647+
# order), and a 0 an incorrectly classified result (pairs are in the
648+
# wrong order).
649+
return self.predict(quadruplets).mean() / 2 + 0.5

test/test_pairs_classifiers.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
from numpy.testing import assert_array_equal
7+
from scipy.spatial.distance import euclidean
78

89
from metric_learn.base_metric import _PairsClassifierMixin, MahalanobisMixin
910
from sklearn.exceptions import NotFittedError
@@ -489,3 +490,31 @@ def breaking_fun(**args): # a function that fails so that we will miss
489490
with pytest.raises(ValueError) as raised_error:
490491
estimator.fit(input_data, labels, calibration_params={'strategy': 'weird'})
491492
assert str(raised_error.value) == expected_msg
493+
494+
495+
@pytest.mark.parametrize('estimator, build_dataset', pairs_learners,
496+
ids=ids_pairs_learners)
497+
def test_accuracy_toy_example(estimator, build_dataset):
498+
"""Test that the accuracy works on some toy example (hence that the
499+
prediction is OK)"""
500+
input_data, labels, preprocessor, X = build_dataset(with_preprocessor=False)
501+
estimator = clone(estimator)
502+
estimator.set_params(preprocessor=preprocessor)
503+
set_random_state(estimator)
504+
estimator.fit(input_data, labels)
505+
# we force the transformation to be identity so that we control what it does
506+
estimator.transformer_ = np.eye(X.shape[1])
507+
# the threshold for similar or dissimilar pairs is half of the distance
508+
# between X[0] and X[1]
509+
estimator.set_threshold(euclidean(X[0], X[1]) / 2)
510+
# We take the two first points and we build 4 regularly spaced points on the
511+
# line they define, so that it's easy to build quadruplets of different
512+
# similarities.
513+
X_test = X[0] + np.arange(4)[:, np.newaxis] * (X[0] - X[1]) / 4
514+
pairs_test = np.array(
515+
[[X_test[0], X_test[1]], # similar
516+
[X_test[0], X_test[3]], # dissimilar
517+
[X_test[1], X_test[2]], # similar
518+
[X_test[2], X_test[3]]]) # similar
519+
y = np.array([-1, 1, 1, -1]) # [F, F, T, F]
520+
assert accuracy_score(estimator.predict(pairs_test), y) == 0.25

test/test_quadruplets_classifiers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,26 @@ def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset,
4040
with pytest.raises(NotFittedError):
4141
estimator.predict(input_data)
4242

43+
44+
@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners,
45+
ids=ids_quadruplets_learners)
46+
def test_accuracy_toy_example(estimator, build_dataset):
47+
"""Test that the default scoring for quadruplets (accuracy) works on some
48+
toy example"""
49+
input_data, labels, preprocessor, X = build_dataset(with_preprocessor=False)
50+
estimator = clone(estimator)
51+
estimator.set_params(preprocessor=preprocessor)
52+
set_random_state(estimator)
53+
estimator.fit(input_data)
54+
# We take the two first points and we build 4 regularly spaced points on the
55+
# line they define, so that it's easy to build quadruplets of different
56+
# similarities.
57+
X_test = X[0] + np.arange(4)[:, np.newaxis] * (X[0] - X[1]) / 4
58+
quadruplets_test = np.array(
59+
[[X_test[0], X_test[2], X_test[0], X_test[1]],
60+
[X_test[1], X_test[3], X_test[1], X_test[0]],
61+
[X_test[1], X_test[2], X_test[0], X_test[3]],
62+
[X_test[3], X_test[0], X_test[2], X_test[1]]])
63+
# we force the transformation to be identity so that we control what it does
64+
estimator.transformer_ = np.eye(X.shape[1])
65+
assert estimator.score(quadruplets_test) == 0.25

0 commit comments

Comments
 (0)