From 7f1373c491da6e6c2d629aa96b9f113271db8c97 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 21 Mar 2019 11:20:36 +0100 Subject: [PATCH] Add classes attribute and test for CalibratedClassifierCV --- metric_learn/base_metric.py | 9 +++++++++ metric_learn/itml.py | 5 +++++ metric_learn/mmc.py | 11 ++++++++++- metric_learn/sdml.py | 5 +++++ test/test_sklearn_compat.py | 29 ++++++++++++++++++++++++++++- 5 files changed, 57 insertions(+), 2 deletions(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 58b8cc5d..47ac4751 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -295,7 +295,16 @@ def get_mahalanobis_matrix(self): class _PairsClassifierMixin(BaseMetricLearner): + """ + Attributes + ---------- + classes_ : `list` + The possible labels of the pairs the metric learner can fit on. + `classes_ = [-1, 1]`, where -1 means points in a pair are dissimilar + (negative label), and 1 means they are similar (positive label). + """ + classes_ = [-1, 1] _tuple_size = 2 # number of points in a tuple, 2 for pairs def predict(self, pairs): diff --git a/metric_learn/itml.py b/metric_learn/itml.py index a0ff05f9..c6ced9c9 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -148,6 +148,11 @@ class ITML(_BaseITML, _PairsClassifierMixin): transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) + + classes_ : `list` + The possible labels of the pairs `ITML` can fit on. `classes_ = [-1, 1]`, + where -1 means points in a pair are dissimilar (negative label), and 1 + means they are similar (positive label). """ def fit(self, pairs, y, bounds=None): diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index f9d3690b..41b6f218 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -29,7 +29,16 @@ class _BaseMMC(MahalanobisMixin): - """Mahalanobis Metric for Clustering (MMC)""" + """Mahalanobis Metric for Clustering (MMC) + + Attributes + ---------- + + classes_ : `list` + The possible labels of the pairs `MMC` can fit on. `classes_ = [-1, 1]`, + where -1 means points in a pair are dissimilar (negative label), and 1 + means they are similar (positive label). + """ _tuple_size = 2 # constraints are pairs diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 78fc4ebc..abb170f9 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -81,6 +81,11 @@ class SDML(_BaseSDML, _PairsClassifierMixin): transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) + + classes_ : `list` + The possible labels of the pairs `SDML` can fit on. `classes_ = [-1, 1]`, + where -1 means points in a pair are dissimilar (negative label), and 1 + means they are similar (positive label). """ def fit(self, pairs, y): diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index d9dce685..65a32282 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -1,5 +1,6 @@ import pytest import unittest +from sklearn.calibration import CalibratedClassifierCV from sklearn.utils.estimator_checks import check_estimator from sklearn.base import TransformerMixin from sklearn.pipeline import make_pipeline @@ -17,7 +18,8 @@ train_test_split, KFold) from sklearn.utils.testing import _get_args from test.test_utils import (metric_learners, ids_metric_learners, - mock_preprocessor) + mock_preprocessor, pairs_learners, + ids_pairs_learners) # Wrap the _Supervised methods with a deterministic wrapper for testing. @@ -87,6 +89,31 @@ def test_mmc(self): # ---------------------- Test scikit-learn compatibility ---------------------- +@pytest.mark.parametrize('with_preprocessor', + [True, + # TODO: uncomment the below line as soon as + # https://github.com/scikit-learn/scikit-learn/ + # issues/13077 is solved: + # False, + ]) +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners, + ids=ids_pairs_learners) +def test_calibrated_classifier_CV(estimator, build_dataset, + with_preprocessor): + """Tests that metric-learn tuples estimators' work with scikit-learn's + CalibratedClassifierCV. + """ + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + calibrated_clf = CalibratedClassifierCV(estimator) + + # test fit and predict_proba + calibrated_clf.fit(input_data, labels) + calibrated_clf.predict_proba(input_data) + + @pytest.mark.parametrize('with_preprocessor', [True, False]) @pytest.mark.parametrize('estimator, build_dataset', metric_learners, ids=ids_metric_learners)