|
4 | 4 |
|
5 | 5 | import pytest
|
6 | 6 | from numpy.testing import assert_array_equal
|
| 7 | +from scipy.spatial.distance import euclidean |
7 | 8 |
|
8 | 9 | from metric_learn.base_metric import _PairsClassifierMixin, MahalanobisMixin
|
9 | 10 | from sklearn.exceptions import NotFittedError
|
@@ -489,3 +490,31 @@ def breaking_fun(**args): # a function that fails so that we will miss
|
489 | 490 | with pytest.raises(ValueError) as raised_error:
|
490 | 491 | estimator.fit(input_data, labels, calibration_params={'strategy': 'weird'})
|
491 | 492 | assert str(raised_error.value) == expected_msg
|
| 493 | + |
| 494 | + |
| 495 | +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners, |
| 496 | + ids=ids_pairs_learners) |
| 497 | +def test_accuracy_toy_example(estimator, build_dataset): |
| 498 | + """Test that the accuracy works on some toy example (hence that the |
| 499 | + prediction is OK)""" |
| 500 | + input_data, labels, preprocessor, X = build_dataset(with_preprocessor=False) |
| 501 | + estimator = clone(estimator) |
| 502 | + estimator.set_params(preprocessor=preprocessor) |
| 503 | + set_random_state(estimator) |
| 504 | + estimator.fit(input_data, labels) |
| 505 | + # we force the transformation to be identity so that we control what it does |
| 506 | + estimator.transformer_ = np.eye(X.shape[1]) |
| 507 | + # the threshold for similar or dissimilar pairs is half of the distance |
| 508 | + # between X[0] and X[1] |
| 509 | + estimator.set_threshold(euclidean(X[0], X[1]) / 2) |
| 510 | + # We take the two first points and we build 4 regularly spaced points on the |
| 511 | + # line they define, so that it's easy to build quadruplets of different |
| 512 | + # similarities. |
| 513 | + X_test = X[0] + np.arange(4)[:, np.newaxis] * (X[0] - X[1]) / 4 |
| 514 | + pairs_test = np.array( |
| 515 | + [[X_test[0], X_test[1]], # similar |
| 516 | + [X_test[0], X_test[3]], # dissimilar |
| 517 | + [X_test[1], X_test[2]], # similar |
| 518 | + [X_test[2], X_test[3]]]) # similar |
| 519 | + y = np.array([-1, 1, 1, -1]) # [F, F, T, F] |
| 520 | + assert accuracy_score(estimator.predict(pairs_test), y) == 0.25 |
0 commit comments