From 776ab915a16c6d6a036da27f9cd7c8019c0edd47 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 22 May 2018 15:39:00 +0200 Subject: [PATCH 1/6] Add tests Basically these are the tests from PR https://github.com/metric-learn/metric-learn/pull/85, but reformatted to use pytest, and formed tuples instead of ConstrainedDatasets. --- test/test_weakly_supervised.py | 250 +++++++++++++++++++++++++++++++++ 1 file changed, 250 insertions(+) create mode 100644 test/test_weakly_supervised.py diff --git a/test/test_weakly_supervised.py b/test/test_weakly_supervised.py new file mode 100644 index 00000000..cf17c405 --- /dev/null +++ b/test/test_weakly_supervised.py @@ -0,0 +1,250 @@ +import pytest +from sklearn.cluster import KMeans +from sklearn.datasets import load_iris +from sklearn.decomposition import PCA +from sklearn.pipeline import make_pipeline +from sklearn.utils import shuffle, check_random_state +from sklearn.utils.estimator_checks import is_public_parameter +from sklearn.utils.testing import (assert_allclose_dense_sparse, + set_random_state) + +from metric_learn import ITML, MMC, SDML, LSML +from metric_learn.constraints import wrap_pairs, Constraints +from sklearn import clone +import numpy as np +from sklearn.model_selection import cross_val_score, train_test_split + + +def build_data(): + RNG = check_random_state(0) + dataset = load_iris() + X, y = shuffle(dataset.data, dataset.target, random_state=RNG) + num_constraints = 20 + constraints = Constraints.random_subset(y) + pairs = constraints.positive_negative_pairs(num_constraints, + same_length=True, + random_state=RNG) + return X, pairs + + +def build_pairs(): + # test that you can do cross validation on a ConstrainedDataset with + # a WeaklySupervisedMetricLearner + X, pairs = build_data() + X_constrained, y = wrap_pairs(X, pairs) + X_constrained, y = shuffle(X_constrained, y) + (X_constrained_train, X_constrained_test, y_train, + y_test) = train_test_split(X_constrained, y) + return (X_constrained, y, X_constrained_train, X_constrained_test, + y_train, y_test) + + +def build_quadruplets(): + # test that you can do cross validation on a ConstrainedDataset with + # a WeaklySupervisedMetricLearner + X, pairs = build_data() + c = np.column_stack(pairs) + X_constrained = X[c] + X_constrained = shuffle(X_constrained) + y = y_train = y_test = None + X_constrained_train, X_constrained_test = train_test_split(X_constrained) + return (X_constrained, y, X_constrained_train, X_constrained_test, + y_train, y_test) + + +list_estimators = [(ITML(), build_pairs), + (LSML(), build_quadruplets), + (MMC(), build_pairs), + (SDML(), build_pairs) + ] + +ids_estimators = ['itml', + 'lsml', + 'mmc', + 'sdml', + ] + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_cross_validation(estimator, build_dataset): + (X_constrained, y, X_constrained_train, X_constrained_test, + y_train, y_test) = build_dataset() + estimator = clone(estimator) + set_random_state(estimator) + + assert np.isfinite(cross_val_score(estimator, X_constrained, y)).all() + + +def check_score(estimator, X_constrained, y): + score = estimator.score(X_constrained, y) + assert np.isfinite(score) + + +def check_predict(estimator, X_constrained): + y_predicted = estimator.predict(X_constrained) + assert len(y_predicted), len(X_constrained) + + +def check_transform(estimator, X_constrained): + X_transformed = estimator.transform(X_constrained) + assert len(X_transformed), len(X_constrained.X) + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_simple_estimator(estimator, build_dataset): + (X_constrained, y, X_constrained_train, X_constrained_test, + y_train, y_test) = build_dataset() + estimator = clone(estimator) + set_random_state(estimator) + + estimator.fit(X_constrained_train, y_train) + check_score(estimator, X_constrained_test, y_test) + check_predict(estimator, X_constrained_test) + check_transform(estimator, X_constrained_test) + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_pipelining_with_transformer(estimator, build_dataset): + """ + Test that weakly supervised estimators fit well into pipelines + """ + # test in a pipeline with KMeans + (X_constrained, y, X_constrained_train, X_constrained_test, + y_train, y_test) = build_dataset() + estimator = clone(estimator) + set_random_state(estimator) + + pipe = make_pipeline(estimator, KMeans()) + pipe.fit(X_constrained_train, y_train) + check_score(pipe, X_constrained_test, y_test) + check_transform(pipe, X_constrained_test) + # we cannot use check_predict because in this case the shape of the + # output is the shape of X_constrained.X, not X_constrained + y_predicted = pipe.predict(X_constrained) + assert len(y_predicted) == len(X_constrained.X) + + # test in a pipeline with PCA + estimator = clone(estimator) + pipe = make_pipeline(estimator, PCA()) + pipe.fit(X_constrained_train, y_train) + check_transform(pipe, X_constrained_test) + + +@pytest.mark.parametrize('estimator', [est[0] for est in list_estimators], + ids=ids_estimators) +def test_no_fit_attributes_set_in_init(estimator): + """Check that Estimator.__init__ doesn't set trailing-_ attributes.""" + # From scikit-learn + estimator = clone(estimator) + for attr in dir(estimator): + if attr.endswith("_") and not attr.startswith("__"): + # This check is for properties, they can be listed in dir + # while at the same time have hasattr return False as long + # as the property getter raises an AttributeError + assert hasattr(estimator, attr), \ + ("By convention, attributes ending with '_' are " + "estimated from data in scikit-learn. Consequently they " + "should not be initialized in the constructor of an " + "estimator but in the fit method. Attribute {!r} " + "was found in estimator {}".format( + attr, type(estimator).__name__)) + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_estimators_fit_returns_self(estimator, build_dataset): + """Check if self is returned when calling fit""" + # From scikit-learn + (X_constrained, y, X_constrained_train, X_constrained_test, + y_train, y_test) = build_dataset() + estimator = clone(estimator) + assert estimator.fit(X_constrained, y) is estimator + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_pipeline_consistency(estimator, build_dataset): + # From scikit learn + # check that make_pipeline(est) gives same score as est + (X_constrained, y, X_constrained_train, X_constrained_test, + y_train, y_test) = build_dataset() + estimator = clone(estimator) + pipeline = make_pipeline(estimator) + estimator.fit(X_constrained, y) + pipeline.fit(X_constrained, y) + + funcs = ["score", "fit_transform"] + + for func_name in funcs: + func = getattr(estimator, func_name, None) + if func is not None: + func_pipeline = getattr(pipeline, func_name) + result = func(X_constrained, y) + result_pipe = func_pipeline(X_constrained, y) + assert_allclose_dense_sparse(result, result_pipe) + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_dict_unchanged(estimator, build_dataset): + # From scikit-learn + (X_constrained, y, X_constrained_train, X_constrained_test, + y_train, y_test) = build_dataset() + estimator = clone(estimator) + if hasattr(estimator, "n_components"): + estimator.n_components = 1 + estimator.fit(X_constrained, y) + for method in ["predict", "transform", "decision_function", + "predict_proba"]: + if hasattr(estimator, method): + dict_before = estimator.__dict__.copy() + getattr(estimator, method)(X_constrained) + assert estimator.__dict__ == dict_before, \ + ("Estimator changes __dict__ during %s" + % method) + + +@pytest.mark.parametrize('estimator, build_dataset', list_estimators, + ids=ids_estimators) +def test_dont_overwrite_parameters(estimator, build_dataset): + # From scikit-learn + # check that fit method only changes or sets private attributes + (X_constrained, y, X_constrained_train, X_constrained_test, + y_train, y_test) = build_dataset() + estimator = clone(estimator) + if hasattr(estimator, "n_components"): + estimator.n_components = 1 + dict_before_fit = estimator.__dict__.copy() + + estimator.fit(X_constrained, y) + dict_after_fit = estimator.__dict__ + + public_keys_after_fit = [key for key in dict_after_fit.keys() + if is_public_parameter(key)] + + attrs_added_by_fit = [key for key in public_keys_after_fit + if key not in dict_before_fit.keys()] + + # check that fit doesn't add any public attribute + assert not attrs_added_by_fit, \ + ("Estimator adds public attribute(s) during" + " the fit method." + " Estimators are only allowed to add private " + "attributes" + " either started with _ or ended" + " with _ but %s added" % ', '.join(attrs_added_by_fit)) + + # check that fit doesn't change any public attribute + attrs_changed_by_fit = [key for key in public_keys_after_fit + if (dict_before_fit[key] + is not dict_after_fit[key])] + + assert not attrs_changed_by_fit, \ + ("Estimator changes public attribute(s) during" + " the fit method. Estimators are only allowed" + " to change attributes started" + " or ended with _, but" + " %s changed" % ', '.join(attrs_changed_by_fit)) From 106cbd2fe0454ad323de22746e495b3828fe04e1 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 24 May 2018 11:50:18 +0200 Subject: [PATCH 2/6] Implement scoring functions (and make tests work): - Make PairsClassifierMixin and QuadrupletsClassifierMixin classes, to implement scoring functions - Implement a new API for supervised wrappers of weakly supervised learning estimators (through the use of base classes, (ex: BaseMMC), from which inherit child classes (ex: MMC and MMC_Supervised) (which is the same idea as in PR https://github.com/metric-learn/metric-learn/pull/85 - Delete tests that use tuples learners as transformers (as we do not want to support this behaviour anymore: it is too complicated to allow such different input types (tuples or points) for the same estimator --- metric_learn/base_metric.py | 109 ++++++++++++++++++++++++++++++++- metric_learn/covariance.py | 4 +- metric_learn/itml.py | 24 +++++--- metric_learn/lfda.py | 4 +- metric_learn/lmnn.py | 4 +- metric_learn/lsml.py | 24 +++++--- metric_learn/mlkr.py | 4 +- metric_learn/mmc.py | 27 ++++---- metric_learn/nca.py | 4 +- metric_learn/rca.py | 4 +- metric_learn/sdml.py | 25 +++++--- test/test_weakly_supervised.py | 34 ---------- 12 files changed, 183 insertions(+), 84 deletions(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 02519de1..457feda4 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -1,9 +1,11 @@ from numpy.linalg import inv, cholesky from sklearn.base import BaseEstimator, TransformerMixin from sklearn.utils.validation import check_array +from sklearn.metrics import roc_auc_score +import numpy as np -class BaseMetricLearner(BaseEstimator, TransformerMixin): +class BaseMetricLearner(BaseEstimator): def __init__(self): raise NotImplementedError('BaseMetricLearner should not be instantiated') @@ -19,6 +21,9 @@ def metric(self): L = self.transformer() return L.T.dot(L) + +class MetricTransformer(TransformerMixin): + def transformer(self): """Computes the transformation matrix from the Mahalanobis matrix. @@ -49,3 +54,105 @@ def transform(self, X=None): X = check_array(X, accept_sparse=True) L = self.transformer() return X.dot(L.T) + + +class _PairsClassifierMixin: + + def predict(self, pairs): + """Predicts the learned similarity between input pairs. + + Returns the learned metric value between samples in every pair. It should + ideally be low for similar samples and high for dissimilar samples. + + Parameters + ---------- + pairs : array-like, shape=(n_constraints, 2, n_features) + A constrained dataset of paired samples. + + Returns + ------- + y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) + The predicted learned metric value between samples in every pair. + """ + pairwise_diffs = pairs[:, 0, :] - pairs[:, 1, :] + return np.sqrt(np.sum(pairwise_diffs.dot(self.metric()) * pairwise_diffs, + axis=1)) + + def decision_function(self, pairs): + return self.predict(pairs) + + def score(self, pairs, y): + """Computes score of pairs similarity prediction. + + Returns the ``roc_auc`` score of the fitted metric learner. It is + computed in the following way: for every value of a threshold + ``t`` we classify all pairs of samples where the predicted distance is + inferior to ``t`` as belonging to the "similar" class, and the other as + belonging to the "dissimilar" class, and we count false positive and + true positives as in a classical ``roc_auc`` curve. + + Parameters + ---------- + pairs : array-like, shape=(n_constraints, 2, n_features) + Input Pairs. + + y : array-like, shape=(n_constraints,) + The corresponding labels. + + Returns + ------- + score : float + The ``roc_auc`` score. + """ + return roc_auc_score(y, self.decision_function(pairs)) + + +class _QuadrupletsClassifierMixin: + + def predict(self, quadruplets): + """Predicts differences between sample similarities in input quadruplets. + + For each quadruplet of samples, computes the difference between the learned + metric of the first pair minus the learned metric of the second pair. + + Parameters + ---------- + quadruplets : array-like, shape=(n_constraints, 4, n_features) + Input quadruplets. + + Returns + ------- + prediction : np.ndarray of floats, shape=(n_constraints,) + Metric differences. + """ + similar_diffs = quadruplets[:, 0, :] - quadruplets[:, 1, :] + dissimilar_diffs = quadruplets[:, 2, :] - quadruplets[:, 3, :] + return (np.sqrt(np.sum(similar_diffs.dot(self.metric()) * + similar_diffs, axis=1)) - + np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) * + dissimilar_diffs, axis=1))) + + def decision_function(self, quadruplets): + return self.predict(quadruplets) + + def score(self, quadruplets, y=None): + """Computes score on an input constrained dataset + + Returns the accuracy score of the following classification task: a record + is correctly classified if the predicted similarity between the first two + samples is higher than that of the last two. + + Parameters + ---------- + quadruplets : array-like, shape=(n_constraints, 4, n_features) + Input quadruplets. + + y : Ignored, for scikit-learn compatibility. + + Returns + ------- + score : float + The quadruplets score. + """ + predicted_sign = self.decision_function(quadruplets) < 0 + return np.sum(predicted_sign) / predicted_sign.shape[0] diff --git a/metric_learn/covariance.py b/metric_learn/covariance.py index 8fc07873..689650b4 100644 --- a/metric_learn/covariance.py +++ b/metric_learn/covariance.py @@ -12,10 +12,10 @@ import numpy as np from sklearn.utils.validation import check_array -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, MetricTransformer -class Covariance(BaseMetricLearner): +class Covariance(BaseMetricLearner, MetricTransformer): def __init__(self): pass diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 3d9aff2a..bbeebfef 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -19,12 +19,13 @@ from sklearn.metrics import pairwise_distances from sklearn.utils.validation import check_array, check_X_y -from .base_metric import BaseMetricLearner +from .base_metric import (BaseMetricLearner, _PairsClassifierMixin, + MetricTransformer) from .constraints import Constraints, wrap_pairs from ._util import vector_norm -class ITML(BaseMetricLearner): +class _BaseITML(BaseMetricLearner): """Information Theoretic Metric Learning (ITML)""" def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, A0=None, verbose=False): @@ -80,8 +81,7 @@ def _process_pairs(self, pairs, y, bounds): y = y.astype(bool) return pairs, y - - def fit(self, pairs, y, bounds=None): + def _fit(self, pairs, y, bounds=None): """Learn the ITML model. Parameters @@ -153,7 +153,13 @@ def metric(self): return self.A_ -class ITML_Supervised(ITML): +class ITML(_BaseITML, _PairsClassifierMixin): + + def fit(self, pairs, y, bounds=None): + return self._fit(pairs, y, bounds=bounds) + + +class ITML_Supervised(_BaseITML, MetricTransformer): """Information Theoretic Metric Learning (ITML)""" def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, num_labeled=np.inf, num_constraints=None, bounds=None, A0=None, @@ -177,9 +183,9 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, verbose : bool, optional if True, prints information while learning """ - ITML.__init__(self, gamma=gamma, max_iter=max_iter, - convergence_threshold=convergence_threshold, - A0=A0, verbose=verbose) + _BaseITML.__init__(self, gamma=gamma, max_iter=max_iter, + convergence_threshold=convergence_threshold, + A0=A0, verbose=verbose) self.num_labeled = num_labeled self.num_constraints = num_constraints self.bounds = bounds @@ -209,4 +215,4 @@ def fit(self, X, y, random_state=np.random): pos_neg = c.positive_negative_pairs(num_constraints, random_state=random_state) pairs, y = wrap_pairs(X, pos_neg) - return ITML.fit(self, pairs, y, bounds=self.bounds) + return _BaseITML._fit(self, pairs, y, bounds=self.bounds) diff --git a/metric_learn/lfda.py b/metric_learn/lfda.py index 809f092b..03df5f24 100644 --- a/metric_learn/lfda.py +++ b/metric_learn/lfda.py @@ -18,10 +18,10 @@ from sklearn.metrics import pairwise_distances from sklearn.utils.validation import check_X_y -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, MetricTransformer -class LFDA(BaseMetricLearner): +class LFDA(BaseMetricLearner, MetricTransformer): ''' Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction Sugiyama, ICML 2006 diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index dea12f0c..581dc72a 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -17,11 +17,11 @@ from sklearn.utils.validation import check_X_y, check_array from sklearn.metrics import euclidean_distances -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, MetricTransformer # commonality between LMNN implementations -class _base_LMNN(BaseMetricLearner): +class _base_LMNN(BaseMetricLearner, MetricTransformer): def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, use_pca=True, verbose=False): diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index b8b69f19..6aab0bcc 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -13,11 +13,12 @@ from six.moves import xrange from sklearn.utils.validation import check_array, check_X_y -from .base_metric import BaseMetricLearner +from .base_metric import (BaseMetricLearner, _QuadrupletsClassifierMixin, + MetricTransformer) from .constraints import Constraints, wrap_pairs -class LSML(BaseMetricLearner): +class _BaseLSML(BaseMetricLearner): def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False): """Initialize LSML. @@ -60,7 +61,7 @@ def _prepare_quadruplets(self, quadruplets, weights): def metric(self): return self.M_ - def fit(self, quadruplets, weights=None): + def _fit(self, quadruplets, weights=None): """Learn the LSML model. Parameters @@ -140,7 +141,13 @@ def _gradient(self, metric): return dMetric -class LSML_Supervised(LSML): +class LSML(_BaseLSML, _QuadrupletsClassifierMixin): + + def fit(self, quadruplets, weights=None): + return self._fit(quadruplets, weights=weights) + + +class LSML_Supervised(_BaseLSML, MetricTransformer): def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf, num_constraints=None, weights=None, verbose=False): """Initialize the learner. @@ -160,8 +167,8 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf, verbose : bool, optional if True, prints information while learning """ - LSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior, - verbose=verbose) + _BaseLSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior, + verbose=verbose) self.num_labeled = num_labeled self.num_constraints = num_constraints self.weights = weights @@ -189,5 +196,6 @@ def fit(self, X, y, random_state=np.random): c = Constraints.random_subset(y, self.num_labeled, random_state=random_state) pos_neg = c.positive_negative_pairs(num_constraints, same_length=True, - random_state=random_state) - return LSML.fit(self, X[np.column_stack(pos_neg)], weights=self.weights) + random_state=random_state) + return _BaseLSML._fit(self, X[np.column_stack(pos_neg)], + weights=self.weights) diff --git a/metric_learn/mlkr.py b/metric_learn/mlkr.py index 35b80495..a16c40aa 100644 --- a/metric_learn/mlkr.py +++ b/metric_learn/mlkr.py @@ -13,12 +13,12 @@ from sklearn.decomposition import PCA from sklearn.utils.validation import check_X_y -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, MetricTransformer EPS = np.finfo(float).eps -class MLKR(BaseMetricLearner): +class MLKR(BaseMetricLearner, MetricTransformer): """Metric Learning for Kernel Regression (MLKR)""" def __init__(self, num_dims=None, A0=None, epsilon=0.01, alpha=0.0001, max_iter=1000): diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index 3f95babd..f04ab33b 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -22,13 +22,13 @@ from sklearn.metrics import pairwise_distances from sklearn.utils.validation import check_array, check_X_y -from .base_metric import BaseMetricLearner +from .base_metric import (BaseMetricLearner, _PairsClassifierMixin, + MetricTransformer) from .constraints import Constraints, wrap_pairs from ._util import vector_norm - -class MMC(BaseMetricLearner): +class _BaseMMC(BaseMetricLearner): """Mahalanobis Metric for Clustering (MMC)""" def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, A0=None, diagonal=False, diagonal_c=1.0, verbose=False): @@ -58,8 +58,7 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, self.diagonal_c = diagonal_c self.verbose = verbose - - def fit(self, pairs, y): + def _fit(self, pairs, y): """Learn the MMC model. Parameters @@ -390,7 +389,13 @@ def transformer(self): return V.T * np.sqrt(np.maximum(0, w[:,None])) -class MMC_Supervised(MMC): +class MMC(_BaseMMC, _PairsClassifierMixin): + + def fit(self, pairs, y): + return self._fit(pairs, y) + + +class MMC_Supervised(_BaseMMC, MetricTransformer): """Mahalanobis Metric for Clustering (MMC)""" def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6, num_labeled=np.inf, num_constraints=None, @@ -418,10 +423,10 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6, verbose : bool, optional if True, prints information while learning """ - MMC.__init__(self, max_iter=max_iter, max_proj=max_proj, - convergence_threshold=convergence_threshold, - A0=A0, diagonal=diagonal, diagonal_c=diagonal_c, - verbose=verbose) + _BaseMMC.__init__(self, max_iter=max_iter, max_proj=max_proj, + convergence_threshold=convergence_threshold, + A0=A0, diagonal=diagonal, diagonal_c=diagonal_c, + verbose=verbose) self.num_labeled = num_labeled self.num_constraints = num_constraints @@ -448,4 +453,4 @@ def fit(self, X, y, random_state=np.random): pos_neg = c.positive_negative_pairs(num_constraints, random_state=random_state) pairs, y = wrap_pairs(X, pos_neg) - return MMC.fit(self, pairs, y) + return _BaseMMC._fit(self, pairs, y) diff --git a/metric_learn/nca.py b/metric_learn/nca.py index 40757d23..9a6af0c3 100644 --- a/metric_learn/nca.py +++ b/metric_learn/nca.py @@ -8,12 +8,12 @@ from six.moves import xrange from sklearn.utils.validation import check_X_y -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, MetricTransformer EPS = np.finfo(float).eps -class NCA(BaseMetricLearner): +class NCA(BaseMetricLearner, MetricTransformer): def __init__(self, num_dims=None, max_iter=100, learning_rate=0.01): self.num_dims = num_dims self.max_iter = max_iter diff --git a/metric_learn/rca.py b/metric_learn/rca.py index 0d9b3620..36dd0aae 100644 --- a/metric_learn/rca.py +++ b/metric_learn/rca.py @@ -18,7 +18,7 @@ from sklearn import decomposition from sklearn.utils.validation import check_array -from .base_metric import BaseMetricLearner +from .base_metric import BaseMetricLearner, MetricTransformer from .constraints import Constraints @@ -35,7 +35,7 @@ def _chunk_mean_centering(data, chunks): return chunk_mask, chunk_data -class RCA(BaseMetricLearner): +class RCA(BaseMetricLearner, MetricTransformer): """Relevant Components Analysis (RCA)""" def __init__(self, num_dims=None, pca_comps=None): """Initialize the learner. diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 9378e260..2cefc12e 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -15,11 +15,12 @@ from sklearn.utils.extmath import pinvh from sklearn.utils.validation import check_array, check_X_y -from .base_metric import BaseMetricLearner +from .base_metric import (BaseMetricLearner, _PairsClassifierMixin, + MetricTransformer) from .constraints import Constraints, wrap_pairs -class SDML(BaseMetricLearner): +class _BaseSDML(BaseMetricLearner): def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, verbose=False): """ @@ -57,7 +58,7 @@ def _prepare_pairs(self, pairs, y): def metric(self): return self.M_ - def fit(self, pairs, y): + def _fit(self, pairs, y): """Learn the SDML model. Parameters @@ -81,7 +82,13 @@ def fit(self, pairs, y): return self -class SDML_Supervised(SDML): +class SDML(_BaseSDML, _PairsClassifierMixin): + + def fit(self, pairs, y): + return self._fit(pairs, y) + + +class SDML_Supervised(_BaseSDML, MetricTransformer): def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, num_labeled=np.inf, num_constraints=None, verbose=False): """ @@ -100,9 +107,9 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, verbose : bool, optional if True, prints information while learning """ - SDML.__init__(self, balance_param=balance_param, - sparsity_param=sparsity_param, use_cov=use_cov, - verbose=verbose) + _BaseSDML.__init__(self, balance_param=balance_param, + sparsity_param=sparsity_param, use_cov=use_cov, + verbose=verbose) self.num_labeled = num_labeled self.num_constraints = num_constraints @@ -133,7 +140,7 @@ def fit(self, X, y, random_state=np.random): c = Constraints.random_subset(y, self.num_labeled, random_state=random_state) pos_neg = c.positive_negative_pairs(num_constraints, - random_state=random_state) + random_state=random_state) pairs, y = wrap_pairs(X, pos_neg) y = 2 * y - 1 - return SDML.fit(self, pairs, y) + return _BaseSDML._fit(self, pairs, y) diff --git a/test/test_weakly_supervised.py b/test/test_weakly_supervised.py index cf17c405..433f621c 100644 --- a/test/test_weakly_supervised.py +++ b/test/test_weakly_supervised.py @@ -86,11 +86,6 @@ def check_predict(estimator, X_constrained): assert len(y_predicted), len(X_constrained) -def check_transform(estimator, X_constrained): - X_transformed = estimator.transform(X_constrained) - assert len(X_transformed), len(X_constrained.X) - - @pytest.mark.parametrize('estimator, build_dataset', list_estimators, ids=ids_estimators) def test_simple_estimator(estimator, build_dataset): @@ -102,35 +97,6 @@ def test_simple_estimator(estimator, build_dataset): estimator.fit(X_constrained_train, y_train) check_score(estimator, X_constrained_test, y_test) check_predict(estimator, X_constrained_test) - check_transform(estimator, X_constrained_test) - - -@pytest.mark.parametrize('estimator, build_dataset', list_estimators, - ids=ids_estimators) -def test_pipelining_with_transformer(estimator, build_dataset): - """ - Test that weakly supervised estimators fit well into pipelines - """ - # test in a pipeline with KMeans - (X_constrained, y, X_constrained_train, X_constrained_test, - y_train, y_test) = build_dataset() - estimator = clone(estimator) - set_random_state(estimator) - - pipe = make_pipeline(estimator, KMeans()) - pipe.fit(X_constrained_train, y_train) - check_score(pipe, X_constrained_test, y_test) - check_transform(pipe, X_constrained_test) - # we cannot use check_predict because in this case the shape of the - # output is the shape of X_constrained.X, not X_constrained - y_predicted = pipe.predict(X_constrained) - assert len(y_predicted) == len(X_constrained.X) - - # test in a pipeline with PCA - estimator = clone(estimator) - pipe = make_pipeline(estimator, PCA()) - pipe.fit(X_constrained_train, y_train) - check_transform(pipe, X_constrained_test) @pytest.mark.parametrize('estimator', [est[0] for est in list_estimators], From 237d467a5b5873c21c711634adfd5ece9599c9eb Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 24 May 2018 13:57:45 +0200 Subject: [PATCH 3/6] fix pep8 errors and unused imports --- metric_learn/base_metric.py | 2 +- metric_learn/lsml.py | 2 +- metric_learn/mmc.py | 1 - test/test_weakly_supervised.py | 4 +--- 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 457feda4..c68d5585 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -1,4 +1,4 @@ -from numpy.linalg import inv, cholesky +from numpy.linalg import cholesky from sklearn.base import BaseEstimator, TransformerMixin from sklearn.utils.validation import check_array from sklearn.metrics import roc_auc_score diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index 6aab0bcc..46c28f1e 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -15,7 +15,7 @@ from .base_metric import (BaseMetricLearner, _QuadrupletsClassifierMixin, MetricTransformer) -from .constraints import Constraints, wrap_pairs +from .constraints import Constraints class _BaseLSML(BaseMetricLearner): diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index f04ab33b..3bb5e3b9 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -19,7 +19,6 @@ from __future__ import print_function, absolute_import, division import numpy as np from six.moves import xrange -from sklearn.metrics import pairwise_distances from sklearn.utils.validation import check_array, check_X_y from .base_metric import (BaseMetricLearner, _PairsClassifierMixin, diff --git a/test/test_weakly_supervised.py b/test/test_weakly_supervised.py index 433f621c..992a5df9 100644 --- a/test/test_weakly_supervised.py +++ b/test/test_weakly_supervised.py @@ -1,7 +1,5 @@ import pytest -from sklearn.cluster import KMeans from sklearn.datasets import load_iris -from sklearn.decomposition import PCA from sklearn.pipeline import make_pipeline from sklearn.utils import shuffle, check_random_state from sklearn.utils.estimator_checks import is_public_parameter @@ -116,7 +114,7 @@ def test_no_fit_attributes_set_in_init(estimator): "should not be initialized in the constructor of an " "estimator but in the fit method. Attribute {!r} " "was found in estimator {}".format( - attr, type(estimator).__name__)) + attr, type(estimator).__name__)) @pytest.mark.parametrize('estimator, build_dataset', list_estimators, From c124ee630906808a76b685fbe67c928b4c9b8539 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 24 May 2018 14:53:23 +0200 Subject: [PATCH 4/6] let the transformer function inside BaseMetricLearner --- metric_learn/base_metric.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index c68d5585..463eb84a 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -21,9 +21,6 @@ def metric(self): L = self.transformer() return L.T.dot(L) - -class MetricTransformer(TransformerMixin): - def transformer(self): """Computes the transformation matrix from the Mahalanobis matrix. @@ -35,6 +32,9 @@ def transformer(self): """ return cholesky(self.metric()).T + +class MetricTransformer(TransformerMixin): + def transform(self, X=None): """Applies the metric transformation. From a70d1a856a6b7aeb6ff5800e5f03ffc48bfb6163 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Fri, 25 May 2018 11:39:32 +0200 Subject: [PATCH 5/6] FIX move docstrings from _fit to fit --- metric_learn/itml.py | 32 ++++++++++++++++---------------- metric_learn/lsml.py | 34 +++++++++++++++++----------------- metric_learn/mmc.py | 28 ++++++++++++++-------------- metric_learn/sdml.py | 24 ++++++++++++------------ 4 files changed, 59 insertions(+), 59 deletions(-) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index bddee467..fc839611 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -80,22 +80,6 @@ def _process_pairs(self, pairs, y, bounds): return pairs, y def _fit(self, pairs, y, bounds=None): - """Learn the ITML model. - - Parameters - ---------- - pairs: array-like, shape=(n_constraints, 2, n_features) - Array of pairs. Each row corresponds to two points. - y: array-like, of shape (n_constraints,) - Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. - bounds : list (pos,neg) pairs, optional - bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg - - Returns - ------- - self : object - Returns the instance. - """ pairs, y = self._process_pairs(pairs, y, bounds) gamma = self.gamma pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] @@ -154,6 +138,22 @@ def metric(self): class ITML(_BaseITML, _PairsClassifierMixin): def fit(self, pairs, y, bounds=None): + """Learn the ITML model. + + Parameters + ---------- + pairs: array-like, shape=(n_constraints, 2, n_features) + Array of pairs. Each row corresponds to two points. + y: array-like, of shape (n_constraints,) + Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. + bounds : list (pos,neg) pairs, optional + bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg + + Returns + ------- + self : object + Returns the instance. + """ return self._fit(pairs, y, bounds=bounds) diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index 46c28f1e..cdbc75d5 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -62,23 +62,6 @@ def metric(self): return self.M_ def _fit(self, quadruplets, weights=None): - """Learn the LSML model. - - Parameters - ---------- - quadruplets : array-like, shape=(n_constraints, 4, n_features) - Each row corresponds to 4 points. In order to supervise the - algorithm in the right way, we should have the four samples ordered - in a way such that: d(pairs[i, 0],X[i, 1]) < d(X[i, 2], X[i, 3]) - for all 0 <= i < n_constraints. - weights : (n_constraints,) array of floats, optional - scale factor for each constraint - - Returns - ------- - self : object - Returns the instance. - """ self._prepare_quadruplets(quadruplets, weights) step_sizes = np.logspace(-10, 0, 10) # Keep track of the best step size and the loss at that step. @@ -144,6 +127,23 @@ def _gradient(self, metric): class LSML(_BaseLSML, _QuadrupletsClassifierMixin): def fit(self, quadruplets, weights=None): + """Learn the LSML model. + + Parameters + ---------- + quadruplets : array-like, shape=(n_constraints, 4, n_features) + Each row corresponds to 4 points. In order to supervise the + algorithm in the right way, we should have the four samples ordered + in a way such that: d(pairs[i, 0],X[i, 1]) < d(X[i, 2], X[i, 3]) + for all 0 <= i < n_constraints. + weights : (n_constraints,) array of floats, optional + scale factor for each constraint + + Returns + ------- + self : object + Returns the instance. + """ return self._fit(quadruplets, weights=weights) diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index 421966a9..f61bb1c7 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -58,20 +58,6 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, self.verbose = verbose def _fit(self, pairs, y): - """Learn the MMC model. - - Parameters - ---------- - pairs: array-like, shape=(n_constraints, 2, n_features) - Array of pairs. Each row corresponds to two points. - y: array-like, of shape (n_constraints,) - Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. - - Returns - ------- - self : object - Returns the instance. - """ pairs, y = self._process_pairs(pairs, y) if self.diagonal: return self._fit_diag(pairs, y) @@ -389,6 +375,20 @@ def transformer(self): class MMC(_BaseMMC, _PairsClassifierMixin): def fit(self, pairs, y): + """Learn the MMC model. + + Parameters + ---------- + pairs: array-like, shape=(n_constraints, 2, n_features) + Array of pairs. Each row corresponds to two points. + y: array-like, of shape (n_constraints,) + Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. + + Returns + ------- + self : object + Returns the instance. + """ return self._fit(pairs, y) diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index b8c08090..2e40ad91 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -59,6 +59,18 @@ def metric(self): return self.M_ def _fit(self, pairs, y): + loss_matrix = self._prepare_pairs(pairs, y) + P = self.M_ + self.balance_param * loss_matrix + emp_cov = pinvh(P) + # hack: ensure positive semidefinite + emp_cov = emp_cov.T.dot(emp_cov) + _, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose) + return self + + +class SDML(_BaseSDML, _PairsClassifierMixin): + + def fit(self, pairs, y): """Learn the SDML model. Parameters @@ -73,18 +85,6 @@ def _fit(self, pairs, y): self : object Returns the instance. """ - loss_matrix = self._prepare_pairs(pairs, y) - P = self.M_ + self.balance_param * loss_matrix - emp_cov = pinvh(P) - # hack: ensure positive semidefinite - emp_cov = emp_cov.T.dot(emp_cov) - _, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose) - return self - - -class SDML(_BaseSDML, _PairsClassifierMixin): - - def fit(self, pairs, y): return self._fit(pairs, y) From b741a9ee423aba5e41253dcc652acc120c168b39 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 5 Jun 2018 12:01:13 +0200 Subject: [PATCH 6/6] FIX: corrections according to reviews https://github.com/metric-learn/metric-learn/pull/95#pullrequestreview-123870498 and https://github.com/metric-learn/metric-learn/pull/95#pullrequestreview-124653719 - replace similarity by metric - replace constrained dataset by pairs/quadruplets - simplify score on quadruplets expression - replace ``X_constrained`` in tests by pairs/quadruplets/tuples --- metric_learn/base_metric.py | 13 +++---- test/test_weakly_supervised.py | 68 +++++++++++++++++----------------- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 463eb84a..889de999 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -59,7 +59,7 @@ def transform(self, X=None): class _PairsClassifierMixin: def predict(self, pairs): - """Predicts the learned similarity between input pairs. + """Predicts the learned metric between input pairs. Returns the learned metric value between samples in every pair. It should ideally be low for similar samples and high for dissimilar samples. @@ -67,7 +67,7 @@ def predict(self, pairs): Parameters ---------- pairs : array-like, shape=(n_constraints, 2, n_features) - A constrained dataset of paired samples. + Input pairs. Returns ------- @@ -110,7 +110,7 @@ def score(self, pairs, y): class _QuadrupletsClassifierMixin: def predict(self, quadruplets): - """Predicts differences between sample similarities in input quadruplets. + """Predicts differences between sample distances in input quadruplets. For each quadruplet of samples, computes the difference between the learned metric of the first pair minus the learned metric of the second pair. @@ -122,7 +122,7 @@ def predict(self, quadruplets): Returns ------- - prediction : np.ndarray of floats, shape=(n_constraints,) + prediction : `numpy.ndarray` of floats, shape=(n_constraints,) Metric differences. """ similar_diffs = quadruplets[:, 0, :] - quadruplets[:, 1, :] @@ -136,7 +136,7 @@ def decision_function(self, quadruplets): return self.predict(quadruplets) def score(self, quadruplets, y=None): - """Computes score on an input constrained dataset + """Computes score on input quadruplets Returns the accuracy score of the following classification task: a record is correctly classified if the predicted similarity between the first two @@ -154,5 +154,4 @@ def score(self, quadruplets, y=None): score : float The quadruplets score. """ - predicted_sign = self.decision_function(quadruplets) < 0 - return np.sum(predicted_sign) / predicted_sign.shape[0] + return - np.mean(np.sign(self.decision_function(quadruplets))) diff --git a/test/test_weakly_supervised.py b/test/test_weakly_supervised.py index 992a5df9..6386d22a 100644 --- a/test/test_weakly_supervised.py +++ b/test/test_weakly_supervised.py @@ -26,27 +26,27 @@ def build_data(): def build_pairs(): - # test that you can do cross validation on a ConstrainedDataset with + # test that you can do cross validation on tuples of points with # a WeaklySupervisedMetricLearner X, pairs = build_data() - X_constrained, y = wrap_pairs(X, pairs) - X_constrained, y = shuffle(X_constrained, y) - (X_constrained_train, X_constrained_test, y_train, - y_test) = train_test_split(X_constrained, y) - return (X_constrained, y, X_constrained_train, X_constrained_test, + pairs, y = wrap_pairs(X, pairs) + pairs, y = shuffle(pairs, y) + (pairs_train, pairs_test, y_train, + y_test) = train_test_split(pairs, y) + return (pairs, y, pairs_train, pairs_test, y_train, y_test) def build_quadruplets(): - # test that you can do cross validation on a ConstrainedDataset with + # test that you can do cross validation on a tuples of points with # a WeaklySupervisedMetricLearner X, pairs = build_data() c = np.column_stack(pairs) - X_constrained = X[c] - X_constrained = shuffle(X_constrained) + quadruplets = X[c] + quadruplets = shuffle(quadruplets) y = y_train = y_test = None - X_constrained_train, X_constrained_test = train_test_split(X_constrained) - return (X_constrained, y, X_constrained_train, X_constrained_test, + quadruplets_train, quadruplets_test = train_test_split(quadruplets) + return (quadruplets, y, quadruplets_train, quadruplets_test, y_train, y_test) @@ -66,35 +66,35 @@ def build_quadruplets(): @pytest.mark.parametrize('estimator, build_dataset', list_estimators, ids=ids_estimators) def test_cross_validation(estimator, build_dataset): - (X_constrained, y, X_constrained_train, X_constrained_test, + (tuples, y, tuples_train, tuples_test, y_train, y_test) = build_dataset() estimator = clone(estimator) set_random_state(estimator) - assert np.isfinite(cross_val_score(estimator, X_constrained, y)).all() + assert np.isfinite(cross_val_score(estimator, tuples, y)).all() -def check_score(estimator, X_constrained, y): - score = estimator.score(X_constrained, y) +def check_score(estimator, tuples, y): + score = estimator.score(tuples, y) assert np.isfinite(score) -def check_predict(estimator, X_constrained): - y_predicted = estimator.predict(X_constrained) - assert len(y_predicted), len(X_constrained) +def check_predict(estimator, tuples): + y_predicted = estimator.predict(tuples) + assert len(y_predicted), len(tuples) @pytest.mark.parametrize('estimator, build_dataset', list_estimators, ids=ids_estimators) def test_simple_estimator(estimator, build_dataset): - (X_constrained, y, X_constrained_train, X_constrained_test, + (tuples, y, tuples_train, tuples_test, y_train, y_test) = build_dataset() estimator = clone(estimator) set_random_state(estimator) - estimator.fit(X_constrained_train, y_train) - check_score(estimator, X_constrained_test, y_test) - check_predict(estimator, X_constrained_test) + estimator.fit(tuples_train, y_train) + check_score(estimator, tuples_test, y_test) + check_predict(estimator, tuples_test) @pytest.mark.parametrize('estimator', [est[0] for est in list_estimators], @@ -122,10 +122,10 @@ def test_no_fit_attributes_set_in_init(estimator): def test_estimators_fit_returns_self(estimator, build_dataset): """Check if self is returned when calling fit""" # From scikit-learn - (X_constrained, y, X_constrained_train, X_constrained_test, + (tuples, y, tuples_train, tuples_test, y_train, y_test) = build_dataset() estimator = clone(estimator) - assert estimator.fit(X_constrained, y) is estimator + assert estimator.fit(tuples, y) is estimator @pytest.mark.parametrize('estimator, build_dataset', list_estimators, @@ -133,12 +133,12 @@ def test_estimators_fit_returns_self(estimator, build_dataset): def test_pipeline_consistency(estimator, build_dataset): # From scikit learn # check that make_pipeline(est) gives same score as est - (X_constrained, y, X_constrained_train, X_constrained_test, + (tuples, y, tuples_train, tuples_test, y_train, y_test) = build_dataset() estimator = clone(estimator) pipeline = make_pipeline(estimator) - estimator.fit(X_constrained, y) - pipeline.fit(X_constrained, y) + estimator.fit(tuples, y) + pipeline.fit(tuples, y) funcs = ["score", "fit_transform"] @@ -146,8 +146,8 @@ def test_pipeline_consistency(estimator, build_dataset): func = getattr(estimator, func_name, None) if func is not None: func_pipeline = getattr(pipeline, func_name) - result = func(X_constrained, y) - result_pipe = func_pipeline(X_constrained, y) + result = func(tuples, y) + result_pipe = func_pipeline(tuples, y) assert_allclose_dense_sparse(result, result_pipe) @@ -155,17 +155,17 @@ def test_pipeline_consistency(estimator, build_dataset): ids=ids_estimators) def test_dict_unchanged(estimator, build_dataset): # From scikit-learn - (X_constrained, y, X_constrained_train, X_constrained_test, + (tuples, y, tuples_train, tuples_test, y_train, y_test) = build_dataset() estimator = clone(estimator) if hasattr(estimator, "n_components"): estimator.n_components = 1 - estimator.fit(X_constrained, y) + estimator.fit(tuples, y) for method in ["predict", "transform", "decision_function", "predict_proba"]: if hasattr(estimator, method): dict_before = estimator.__dict__.copy() - getattr(estimator, method)(X_constrained) + getattr(estimator, method)(tuples) assert estimator.__dict__ == dict_before, \ ("Estimator changes __dict__ during %s" % method) @@ -176,14 +176,14 @@ def test_dict_unchanged(estimator, build_dataset): def test_dont_overwrite_parameters(estimator, build_dataset): # From scikit-learn # check that fit method only changes or sets private attributes - (X_constrained, y, X_constrained_train, X_constrained_test, + (tuples, y, tuples_train, tuples_test, y_train, y_test) = build_dataset() estimator = clone(estimator) if hasattr(estimator, "n_components"): estimator.n_components = 1 dict_before_fit = estimator.__dict__.copy() - estimator.fit(X_constrained, y) + estimator.fit(tuples, y) dict_after_fit = estimator.__dict__ public_keys_after_fit = [key for key in dict_after_fit.keys()