diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py
index 02519de1..889de999 100644
--- a/metric_learn/base_metric.py
+++ b/metric_learn/base_metric.py
@@ -1,9 +1,11 @@
-from numpy.linalg import inv, cholesky
+from numpy.linalg import cholesky
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.utils.validation import check_array
+from sklearn.metrics import roc_auc_score
+import numpy as np
 
-class BaseMetricLearner(BaseEstimator, TransformerMixin):
+class BaseMetricLearner(BaseEstimator):
   def __init__(self):
     raise NotImplementedError('BaseMetricLearner should not be instantiated')
@@ -30,6 +32,9 @@ def transformer(self):
     """
     return cholesky(self.metric()).T
 
+
+class MetricTransformer(TransformerMixin):
+
   def transform(self, X=None):
     """Applies the metric transformation.
@@ -49,3 +54,104 @@ def transform(self, X=None):
     X = check_array(X, accept_sparse=True)
     L = self.transformer()
     return X.dot(L.T)
+
+
+class _PairsClassifierMixin:
+
+  def predict(self, pairs):
+    """Predicts the learned metric between input pairs.
+
+    Returns the learned metric value between samples in every pair. It should
+    ideally be low for similar samples and high for dissimilar samples.
+
+    Parameters
+    ----------
+    pairs : array-like, shape=(n_constraints, 2, n_features)
+      Input pairs.
+
+    Returns
+    -------
+    y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
+      The predicted learned metric value between samples in every pair.
+    """
+    pairwise_diffs = pairs[:, 0, :] - pairs[:, 1, :]
+    return np.sqrt(np.sum(pairwise_diffs.dot(self.metric()) * pairwise_diffs,
+                          axis=1))
+
+  def decision_function(self, pairs):
+    return self.predict(pairs)
+
+  def score(self, pairs, y):
+    """Computes score of pairs similarity prediction.
+
+    Returns the ``roc_auc`` score of the fitted metric learner. It is
+    computed in the following way: for every value of a threshold
+    ``t`` we classify all pairs of samples where the predicted distance is
+    below ``t`` as belonging to the "similar" class, and the others as
+    belonging to the "dissimilar" class, and we count false positives and
+    true positives as in a classical ``roc_auc`` curve.
+
+    Parameters
+    ----------
+    pairs : array-like, shape=(n_constraints, 2, n_features)
+      Input pairs.
+
+    y : array-like, shape=(n_constraints,)
+      The corresponding labels.
+
+    Returns
+    -------
+    score : float
+      The ``roc_auc`` score.
+    """
+    # predict returns distances (low = similar), while roc_auc_score expects
+    # higher scores for the positive (+1, "similar") class, so negate.
+    return roc_auc_score(y, -self.decision_function(pairs))
+
+
+class _QuadrupletsClassifierMixin:
+
+  def predict(self, quadruplets):
+    """Predicts differences between sample distances in input quadruplets.
+
+    For each quadruplet of samples, computes the learned metric of the first
+    pair minus the learned metric of the second pair.
+
+    Parameters
+    ----------
+    quadruplets : array-like, shape=(n_constraints, 4, n_features)
+      Input quadruplets.
+
+    Returns
+    -------
+    prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
+      Metric differences.
+    """
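+    # A negative value means the first pair is closer than the second under
+    # the learned metric; a positive value means it is farther apart.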
+ """ + similar_diffs = quadruplets[:, 0, :] - quadruplets[:, 1, :] + dissimilar_diffs = quadruplets[:, 2, :] - quadruplets[:, 3, :] + return (np.sqrt(np.sum(similar_diffs.dot(self.metric()) * + similar_diffs, axis=1)) - + np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) * + dissimilar_diffs, axis=1))) + + def decision_function(self, quadruplets): + return self.predict(quadruplets) + + def score(self, quadruplets, y=None): + """Computes score on input quadruplets + + Returns the accuracy score of the following classification task: a record + is correctly classified if the predicted similarity between the first two + samples is higher than that of the last two. + + Parameters + ---------- + quadruplets : array-like, shape=(n_constraints, 4, n_features) + Input quadruplets. + + y : Ignored, for scikit-learn compatibility. + + Returns + ------- + score : float + The quadruplets score. + """ + return - np.mean(np.sign(self.decision_function(quadruplets))) diff --git a/metric_learn/covariance.py b/metric_learn/covariance.py index 8fc07873..689650b4 100644 --- a/metric_learn/covariance.py +++ b/metric_learn/covariance.py @@ -12,10 +12,10 @@ import numpy as np from sklearn.utils.validation import check_array -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, MetricTransformer -class Covariance(BaseMetricLearner): +class Covariance(BaseMetricLearner, MetricTransformer): def __init__(self): pass diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 4d719591..fc839611 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -19,12 +19,13 @@ from sklearn.metrics import pairwise_distances from sklearn.utils.validation import check_array, check_X_y -from .base_metric import BaseMetricLearner +from .base_metric import (BaseMetricLearner, _PairsClassifierMixin, + MetricTransformer) from .constraints import Constraints, wrap_pairs from ._util import vector_norm -class ITML(BaseMetricLearner): +class _BaseITML(BaseMetricLearner): """Information Theoretic Metric Learning (ITML)""" def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, A0=None, verbose=False): @@ -78,24 +79,7 @@ def _process_pairs(self, pairs, y, bounds): y = np.hstack([np.ones(len(pos_pairs)), - np.ones(len(neg_pairs))]) return pairs, y - - def fit(self, pairs, y, bounds=None): - """Learn the ITML model. - - Parameters - ---------- - pairs: array-like, shape=(n_constraints, 2, n_features) - Array of pairs. Each row corresponds to two points. - y: array-like, of shape (n_constraints,) - Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. - bounds : list (pos,neg) pairs, optional - bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg - - Returns - ------- - self : object - Returns the instance. - """ + def _fit(self, pairs, y, bounds=None): pairs, y = self._process_pairs(pairs, y, bounds) gamma = self.gamma pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] @@ -151,7 +135,29 @@ def metric(self): return self.A_ -class ITML_Supervised(ITML): +class ITML(_BaseITML, _PairsClassifierMixin): + + def fit(self, pairs, y, bounds=None): + """Learn the ITML model. + + Parameters + ---------- + pairs: array-like, shape=(n_constraints, 2, n_features) + Array of pairs. Each row corresponds to two points. + y: array-like, of shape (n_constraints,) + Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. + bounds : list (pos,neg) pairs, optional + bounds on similarity, s.t. 
d(X[a],X[b]) < pos and d(X[c],X[d]) > neg + + Returns + ------- + self : object + Returns the instance. + """ + return self._fit(pairs, y, bounds=bounds) + + +class ITML_Supervised(_BaseITML, MetricTransformer): """Information Theoretic Metric Learning (ITML)""" def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, num_labeled=np.inf, num_constraints=None, bounds=None, A0=None, @@ -175,9 +181,9 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, verbose : bool, optional if True, prints information while learning """ - ITML.__init__(self, gamma=gamma, max_iter=max_iter, - convergence_threshold=convergence_threshold, - A0=A0, verbose=verbose) + _BaseITML.__init__(self, gamma=gamma, max_iter=max_iter, + convergence_threshold=convergence_threshold, + A0=A0, verbose=verbose) self.num_labeled = num_labeled self.num_constraints = num_constraints self.bounds = bounds @@ -207,4 +213,4 @@ def fit(self, X, y, random_state=np.random): pos_neg = c.positive_negative_pairs(num_constraints, random_state=random_state) pairs, y = wrap_pairs(X, pos_neg) - return ITML.fit(self, pairs, y, bounds=self.bounds) + return _BaseITML._fit(self, pairs, y, bounds=self.bounds) diff --git a/metric_learn/lfda.py b/metric_learn/lfda.py index 809f092b..03df5f24 100644 --- a/metric_learn/lfda.py +++ b/metric_learn/lfda.py @@ -18,10 +18,10 @@ from sklearn.metrics import pairwise_distances from sklearn.utils.validation import check_X_y -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, MetricTransformer -class LFDA(BaseMetricLearner): +class LFDA(BaseMetricLearner, MetricTransformer): ''' Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction Sugiyama, ICML 2006 diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index dea12f0c..581dc72a 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -17,11 +17,11 @@ from sklearn.utils.validation import check_X_y, check_array from sklearn.metrics import euclidean_distances -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, MetricTransformer # commonality between LMNN implementations -class _base_LMNN(BaseMetricLearner): +class _base_LMNN(BaseMetricLearner, MetricTransformer): def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, use_pca=True, verbose=False): diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index b8b69f19..cdbc75d5 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -13,11 +13,12 @@ from six.moves import xrange from sklearn.utils.validation import check_array, check_X_y -from .base_metric import BaseMetricLearner -from .constraints import Constraints, wrap_pairs +from .base_metric import (BaseMetricLearner, _QuadrupletsClassifierMixin, + MetricTransformer) +from .constraints import Constraints -class LSML(BaseMetricLearner): +class _BaseLSML(BaseMetricLearner): def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False): """Initialize LSML. @@ -60,24 +61,7 @@ def _prepare_quadruplets(self, quadruplets, weights): def metric(self): return self.M_ - def fit(self, quadruplets, weights=None): - """Learn the LSML model. - - Parameters - ---------- - quadruplets : array-like, shape=(n_constraints, 4, n_features) - Each row corresponds to 4 points. In order to supervise the - algorithm in the right way, we should have the four samples ordered - in a way such that: d(pairs[i, 0],X[i, 1]) < d(X[i, 2], X[i, 3]) - for all 0 <= i < n_constraints. 
+    return _BaseLSML._fit(self, X[np.column_stack(pos_neg)],
+                          weights=self.weights)
diff --git a/metric_learn/mlkr.py b/metric_learn/mlkr.py
index 35b80495..a16c40aa 100644
--- a/metric_learn/mlkr.py
+++ b/metric_learn/mlkr.py
@@ -13,12 +13,12 @@
 from sklearn.decomposition import PCA
 from sklearn.utils.validation import check_X_y
 
-from .base_metric import BaseMetricLearner
+from .base_metric import BaseMetricLearner, MetricTransformer
 
 EPS = np.finfo(float).eps
 
-class MLKR(BaseMetricLearner):
+class MLKR(BaseMetricLearner, MetricTransformer):
   """Metric Learning for Kernel Regression (MLKR)"""
   def __init__(self, num_dims=None, A0=None, epsilon=0.01, alpha=0.0001,
                max_iter=1000):
diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py
index a72fa14b..f61bb1c7 100644
--- a/metric_learn/mmc.py
+++ b/metric_learn/mmc.py
@@ -19,16 +19,15 @@
 from __future__ import print_function, absolute_import, division
 import numpy as np
 from six.moves import xrange
-from sklearn.metrics import pairwise_distances
 from sklearn.utils.validation import check_array, check_X_y
 
-from .base_metric import BaseMetricLearner
+from .base_metric import (BaseMetricLearner, _PairsClassifierMixin,
+                          MetricTransformer)
 from .constraints import Constraints, wrap_pairs
 from ._util import vector_norm
 
-
-class MMC(BaseMetricLearner):
+class _BaseMMC(BaseMetricLearner):
   """Mahalanobis Metric for Clustering (MMC)"""
   def __init__(self, max_iter=100, max_proj=10000,
                convergence_threshold=1e-3, A0=None, diagonal=False,
               diagonal_c=1.0, verbose=False):
@@ -58,22 +57,7 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3,
     self.diagonal_c = diagonal_c
     self.verbose = verbose
-
-  def fit(self, pairs, y):
-    """Learn the MMC model.
-
-    Parameters
-    ----------
-    pairs: array-like, shape=(n_constraints, 2, n_features)
-      Array of pairs. Each row corresponds to two points.
-    y: array-like, of shape (n_constraints,)
-      Labels of constraints. Should be -1 for dissimilar pair, 1 for similar.
-
-    Returns
-    -------
-    self : object
-      Returns the instance.
-    """
+  def _fit(self, pairs, y):
     pairs, y = self._process_pairs(pairs, y)
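+    # With diagonal=True only a diagonal metric (one weight per feature) is
+    # learned via the cheaper _fit_diag path; diagonal_c is the weight of
+    # the dissimilarity term in that variant.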
     if self.diagonal:
       return self._fit_diag(pairs, y)
@@ -388,7 +372,27 @@ def transformer(self):
     return V.T * np.sqrt(np.maximum(0, w[:,None]))
 
 
-class MMC_Supervised(MMC):
+class MMC(_BaseMMC, _PairsClassifierMixin):
+
+  def fit(self, pairs, y):
+    """Learn the MMC model.
+
+    Parameters
+    ----------
+    pairs : array-like, shape=(n_constraints, 2, n_features)
+      Array of pairs. Each row corresponds to two points.
+    y : array-like, of shape (n_constraints,)
+      Labels of constraints. Should be -1 for dissimilar pairs, 1 for similar.
+
+    Returns
+    -------
+    self : object
+      Returns the instance.
+    """
+    return self._fit(pairs, y)
+
+
+class MMC_Supervised(_BaseMMC, MetricTransformer):
   """Mahalanobis Metric for Clustering (MMC)"""
   def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
                num_labeled=np.inf, num_constraints=None,
@@ -416,10 +420,10 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
     verbose : bool, optional
       if True, prints information while learning
     """
-    MMC.__init__(self, max_iter=max_iter, max_proj=max_proj,
-                 convergence_threshold=convergence_threshold,
-                 A0=A0, diagonal=diagonal, diagonal_c=diagonal_c,
-                 verbose=verbose)
+    _BaseMMC.__init__(self, max_iter=max_iter, max_proj=max_proj,
+                      convergence_threshold=convergence_threshold,
+                      A0=A0, diagonal=diagonal, diagonal_c=diagonal_c,
+                      verbose=verbose)
     self.num_labeled = num_labeled
     self.num_constraints = num_constraints
 
@@ -446,4 +450,4 @@ def fit(self, X, y, random_state=np.random):
     pos_neg = c.positive_negative_pairs(num_constraints,
                                         random_state=random_state)
     pairs, y = wrap_pairs(X, pos_neg)
-    return MMC.fit(self, pairs, y)
+    return _BaseMMC._fit(self, pairs, y)
diff --git a/metric_learn/nca.py b/metric_learn/nca.py
index 40757d23..9a6af0c3 100644
--- a/metric_learn/nca.py
+++ b/metric_learn/nca.py
@@ -8,12 +8,12 @@
 from six.moves import xrange
 from sklearn.utils.validation import check_X_y
 
-from .base_metric import BaseMetricLearner
+from .base_metric import BaseMetricLearner, MetricTransformer
 
 EPS = np.finfo(float).eps
 
-class NCA(BaseMetricLearner):
+class NCA(BaseMetricLearner, MetricTransformer):
   def __init__(self, num_dims=None, max_iter=100, learning_rate=0.01):
     self.num_dims = num_dims
     self.max_iter = max_iter
diff --git a/metric_learn/rca.py b/metric_learn/rca.py
index 0d9b3620..36dd0aae 100644
--- a/metric_learn/rca.py
+++ b/metric_learn/rca.py
@@ -18,7 +18,7 @@
 from sklearn import decomposition
 from sklearn.utils.validation import check_array
 
-from .base_metric import BaseMetricLearner
+from .base_metric import BaseMetricLearner, MetricTransformer
 from .constraints import Constraints
 
@@ -35,7 +35,7 @@ def _chunk_mean_centering(data, chunks):
   return chunk_mask, chunk_data
 
 
-class RCA(BaseMetricLearner):
+class RCA(BaseMetricLearner, MetricTransformer):
   """Relevant Components Analysis (RCA)"""
   def __init__(self, num_dims=None, pca_comps=None):
     """Initialize the learner.
diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py
index 19919ab1..2e40ad91 100644
--- a/metric_learn/sdml.py
+++ b/metric_learn/sdml.py
@@ -15,11 +15,12 @@
 from sklearn.utils.extmath import pinvh
 from sklearn.utils.validation import check_array, check_X_y
 
-from .base_metric import BaseMetricLearner
+from .base_metric import (BaseMetricLearner, _PairsClassifierMixin,
+                          MetricTransformer)
 from .constraints import Constraints, wrap_pairs
 
-class SDML(BaseMetricLearner):
+class _BaseSDML(BaseMetricLearner):
   def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
                verbose=False):
     """
@@ -57,6 +58,18 @@ def _prepare_pairs(self, pairs, y):
   def metric(self):
     return self.M_
 
+  def _fit(self, pairs, y):
+    loss_matrix = self._prepare_pairs(pairs, y)
+    P = self.M_ + self.balance_param * loss_matrix
+    emp_cov = pinvh(P)
+    # hack: ensure positive semidefinite
+    emp_cov = emp_cov.T.dot(emp_cov)
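+    # graph_lasso fits a sparse inverse covariance to emp_cov and returns
+    # (covariance, precision); the sparse precision matrix becomes the
+    # learned metric M_.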
+    _, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
+    return self
+
+
+class SDML(_BaseSDML, _PairsClassifierMixin):
+
   def fit(self, pairs, y):
     """Learn the SDML model.
 
@@ -72,16 +85,10 @@ def fit(self, pairs, y):
     self : object
       Returns the instance.
     """
-    loss_matrix = self._prepare_pairs(pairs, y)
-    P = self.M_ + self.balance_param * loss_matrix
-    emp_cov = pinvh(P)
-    # hack: ensure positive semidefinite
-    emp_cov = emp_cov.T.dot(emp_cov)
-    _, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
-    return self
+    return self._fit(pairs, y)
 
 
-class SDML_Supervised(SDML):
+class SDML_Supervised(_BaseSDML, MetricTransformer):
   def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
                num_labeled=np.inf, num_constraints=None, verbose=False):
     """
@@ -100,9 +107,9 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
     verbose : bool, optional
       if True, prints information while learning
     """
-    SDML.__init__(self, balance_param=balance_param,
-                  sparsity_param=sparsity_param, use_cov=use_cov,
-                  verbose=verbose)
+    _BaseSDML.__init__(self, balance_param=balance_param,
+                       sparsity_param=sparsity_param, use_cov=use_cov,
+                       verbose=verbose)
     self.num_labeled = num_labeled
     self.num_constraints = num_constraints
 
@@ -133,6 +140,6 @@ def fit(self, X, y, random_state=np.random):
     c = Constraints.random_subset(y, self.num_labeled,
                                   random_state=random_state)
     pos_neg = c.positive_negative_pairs(num_constraints,
-                                       random_state=random_state)
+                                        random_state=random_state)
     pairs, y = wrap_pairs(X, pos_neg)
-    return SDML.fit(self, pairs, y)
+    return _BaseSDML._fit(self, pairs, y)
diff --git a/test/test_weakly_supervised.py b/test/test_weakly_supervised.py
new file mode 100644
index 00000000..6386d22a
--- /dev/null
+++ b/test/test_weakly_supervised.py
@@ -0,0 +1,214 @@
+import pytest
+from sklearn.datasets import load_iris
+from sklearn.pipeline import make_pipeline
+from sklearn.utils import shuffle, check_random_state
+from sklearn.utils.estimator_checks import is_public_parameter
+from sklearn.utils.testing import (assert_allclose_dense_sparse,
+                                   set_random_state)
+
+from metric_learn import ITML, MMC, SDML, LSML
+from metric_learn.constraints import wrap_pairs, Constraints
+from sklearn import clone
+import numpy as np
+from sklearn.model_selection import cross_val_score, train_test_split
+
+
+def build_data():
+  RNG = check_random_state(0)
+  dataset = load_iris()
+  X, y = shuffle(dataset.data, dataset.target, random_state=RNG)
+  num_constraints = 20
+  constraints = Constraints.random_subset(y)
+  pairs = constraints.positive_negative_pairs(num_constraints,
+                                              same_length=True,
+                                              random_state=RNG)
+  return X, pairs
+
+
+def build_pairs():
+  # test that you can do cross validation on tuples of points with
+  # a WeaklySupervisedMetricLearner
+  X, pairs = build_data()
+  pairs, y = wrap_pairs(X, pairs)
+  pairs, y = shuffle(pairs, y)
+  (pairs_train, pairs_test, y_train,
+   y_test) = train_test_split(pairs, y)
+  return (pairs, y, pairs_train, pairs_test,
+          y_train, y_test)
+
+
+def build_quadruplets():
+  # test that you can do cross validation on tuples of points with
+  # a WeaklySupervisedMetricLearner
+  X, pairs = build_data()
+  c = np.column_stack(pairs)
+  quadruplets = X[c]
+  quadruplets = shuffle(quadruplets)
+  y = y_train = y_test = None
+  quadruplets_train, quadruplets_test = train_test_split(quadruplets)
+  return (quadruplets, y, quadruplets_train, quadruplets_test,
+          y_train, y_test)
+
+
+list_estimators = [(ITML(), build_pairs),
+                   (LSML(), build_quadruplets),
+                   (MMC(), build_pairs),
+                   (SDML(), build_pairs)
+                   ]
+
+ids_estimators = ['itml',
+                  'lsml',
+                  'mmc',
+                  'sdml',
+                  ]
+
+
+@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
+                         ids=ids_estimators)
+def test_cross_validation(estimator, build_dataset):
+  (tuples, y, tuples_train, tuples_test,
+   y_train, y_test) = build_dataset()
+  estimator = clone(estimator)
+  set_random_state(estimator)
+
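+  # cross_val_score slices along the first axis of `tuples`, which is valid
+  # here because each constraint is a single row of the 3D tuples array.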
+  assert np.isfinite(cross_val_score(estimator, tuples, y)).all()
+
+
+def check_score(estimator, tuples, y):
+  score = estimator.score(tuples, y)
+  assert np.isfinite(score)
+
+
+def check_predict(estimator, tuples):
+  y_predicted = estimator.predict(tuples)
+  assert len(y_predicted) == len(tuples)
+
+
+@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
+                         ids=ids_estimators)
+def test_simple_estimator(estimator, build_dataset):
+  (tuples, y, tuples_train, tuples_test,
+   y_train, y_test) = build_dataset()
+  estimator = clone(estimator)
+  set_random_state(estimator)
+
+  estimator.fit(tuples_train, y_train)
+  check_score(estimator, tuples_test, y_test)
+  check_predict(estimator, tuples_test)
+
+
+@pytest.mark.parametrize('estimator', [est[0] for est in list_estimators],
+                         ids=ids_estimators)
+def test_no_fit_attributes_set_in_init(estimator):
+  """Check that Estimator.__init__ doesn't set trailing-_ attributes."""
+  # From scikit-learn
+  estimator = clone(estimator)
+  for attr in dir(estimator):
+    if attr.endswith("_") and not attr.startswith("__"):
+      # This check is for properties, they can be listed in dir
+      # while at the same time have hasattr return False as long
+      # as the property getter raises an AttributeError
+      assert hasattr(estimator, attr), \
+          ("By convention, attributes ending with '_' are "
+           "estimated from data in scikit-learn. Consequently they "
+           "should not be initialized in the constructor of an "
+           "estimator but in the fit method. Attribute {!r} "
+           "was found in estimator {}".format(
+               attr, type(estimator).__name__))
+
+
+@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
+                         ids=ids_estimators)
+def test_estimators_fit_returns_self(estimator, build_dataset):
+  """Check if self is returned when calling fit"""
+  # From scikit-learn
+  (tuples, y, tuples_train, tuples_test,
+   y_train, y_test) = build_dataset()
+  estimator = clone(estimator)
+  assert estimator.fit(tuples, y) is estimator
+
+
+@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
+                         ids=ids_estimators)
+def test_pipeline_consistency(estimator, build_dataset):
+  # From scikit-learn
+  # check that make_pipeline(est) gives same score as est
+  (tuples, y, tuples_train, tuples_test,
+   y_train, y_test) = build_dataset()
+  estimator = clone(estimator)
+  pipeline = make_pipeline(estimator)
+  estimator.fit(tuples, y)
+  pipeline.fit(tuples, y)
+
+  funcs = ["score", "fit_transform"]
+
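+  # Pair/quadruplet learners no longer inherit TransformerMixin, so
+  # fit_transform may be missing; absent methods are skipped rather than
+  # treated as failures.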
+  for func_name in funcs:
+    func = getattr(estimator, func_name, None)
+    if func is not None:
+      func_pipeline = getattr(pipeline, func_name)
+      result = func(tuples, y)
+      result_pipe = func_pipeline(tuples, y)
+      assert_allclose_dense_sparse(result, result_pipe)
+
+
+@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
+                         ids=ids_estimators)
+def test_dict_unchanged(estimator, build_dataset):
+  # From scikit-learn
+  (tuples, y, tuples_train, tuples_test,
+   y_train, y_test) = build_dataset()
+  estimator = clone(estimator)
+  if hasattr(estimator, "n_components"):
+    estimator.n_components = 1
+  estimator.fit(tuples, y)
+  for method in ["predict", "transform", "decision_function",
+                 "predict_proba"]:
+    if hasattr(estimator, method):
+      dict_before = estimator.__dict__.copy()
+      getattr(estimator, method)(tuples)
+      assert estimator.__dict__ == dict_before, \
+          ("Estimator changes __dict__ during %s"
+           % method)
+
+
+@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
+                         ids=ids_estimators)
+def test_dont_overwrite_parameters(estimator, build_dataset):
+  # From scikit-learn
+  # check that fit method only changes or sets private attributes
+  (tuples, y, tuples_train, tuples_test,
+   y_train, y_test) = build_dataset()
+  estimator = clone(estimator)
+  if hasattr(estimator, "n_components"):
+    estimator.n_components = 1
+  dict_before_fit = estimator.__dict__.copy()
+
+  estimator.fit(tuples, y)
+  dict_after_fit = estimator.__dict__
+
+  public_keys_after_fit = [key for key in dict_after_fit.keys()
+                           if is_public_parameter(key)]
+
+  attrs_added_by_fit = [key for key in public_keys_after_fit
+                        if key not in dict_before_fit.keys()]
+
+  # check that fit doesn't add any public attribute
+  assert not attrs_added_by_fit, \
+      ("Estimator adds public attribute(s) during the fit method."
+       " Estimators are only allowed to add private attributes"
+       " either starting with _ or ending with _, but %s was added"
+       % ', '.join(attrs_added_by_fit))
+
+  # check that fit doesn't change any public attribute
+  attrs_changed_by_fit = [key for key in public_keys_after_fit
+                          if (dict_before_fit[key]
+                              is not dict_after_fit[key])]
+
+  assert not attrs_changed_by_fit, \
+      ("Estimator changes public attribute(s) during the fit method."
+       " Estimators are only allowed to change attributes starting"
+       " or ending with _, but %s changed"
+       % ', '.join(attrs_changed_by_fit))
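+
+
+@pytest.mark.parametrize('estimator, build_dataset', list_estimators,
+                         ids=ids_estimators)
+def test_decision_function_equals_predict(estimator, build_dataset):
+  # Sanity check: in this patch decision_function simply delegates to
+  # predict for both mixins, so the two must agree exactly.
+  (tuples, y, tuples_train, tuples_test,
+   y_train, y_test) = build_dataset()
+  estimator = clone(estimator)
+  set_random_state(estimator)
+  estimator.fit(tuples_train, y_train)
+  assert np.array_equal(estimator.predict(tuples_test),
+                        estimator.decision_function(tuples_test))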