diff --git a/metric_learn/rca.py b/metric_learn/rca.py index 45c9bbf2..1dbffdd6 100644 --- a/metric_learn/rca.py +++ b/metric_learn/rca.py @@ -17,6 +17,7 @@ from six.moves import xrange from sklearn import decomposition from sklearn.base import TransformerMixin +from sklearn.exceptions import ChangedBehaviorWarning from ._util import _check_n_components from .base_metric import MahalanobisMixin @@ -48,7 +49,7 @@ class RCA(MahalanobisMixin, TransformerMixin): """ def __init__(self, n_components=None, num_dims='deprecated', - pca_comps=None, preprocessor=None): + pca_comps='deprecated', preprocessor=None): """Initialize the learner. Parameters @@ -62,12 +63,10 @@ def __init__(self, n_components=None, num_dims='deprecated', `num_dims` was deprecated in version 0.5.0 and will be removed in 0.6.0. Use `n_components` instead. - pca_comps : int, float, None or string - Number of components to keep during PCA preprocessing. - If None (default), does not perform PCA. - If ``0 < pca_comps < 1``, it is used as - the minimum explained variance ratio. - See sklearn.decomposition.PCA for more details. + pca_comps : Not used + .. deprecated:: 0.5.0 + `pca_comps` was deprecated in version 0.5.0 and will + be removed in 0.6.0. preprocessor : array-like, shape=(n_samples, n_features) or callable The preprocessor to call to get tuples from indices. If array-like, @@ -83,8 +82,9 @@ def _check_dimension(self, rank, X): if rank < d: warnings.warn('The inner covariance matrix is not invertible, ' 'so the transformation matrix may contain Nan values. ' - 'You should adjust pca_comps to remove noise and ' - 'redundant information.') + 'You should reduce the dimensionality of your input,' + 'for instance using `sklearn.decomposition.PCA` as a ' + 'preprocessing step.') dim = _check_n_components(d, self.n_components) return dim @@ -105,25 +105,33 @@ def fit(self, X, chunks): ' It has been deprecated in version 0.5.0 and will be' ' removed in 0.6.0. Use "n_components" instead', DeprecationWarning) + + if self.pca_comps != 'deprecated': + warnings.warn( + '"pca_comps" parameter is not used. ' + 'It has been deprecated in version 0.5.0 and will be' + 'removed in 0.6.0. RCA will not do PCA preprocessing anymore. If ' + 'you still want to do it, you could use ' + '`sklearn.decomposition.PCA` and an `sklearn.pipeline.Pipeline`.', + DeprecationWarning) + X, chunks = self._prepare_inputs(X, chunks, ensure_min_samples=2) - # PCA projection to remove noise and redundant information. - if self.pca_comps is not None: - pca = decomposition.PCA(n_components=self.pca_comps) - X_t = pca.fit_transform(X) - M_pca = pca.components_ - else: - X_t = X - X.mean(axis=0) - M_pca = None + warnings.warn( + "RCA will no longer center the data before training. If you want " + "to do some preprocessing, you should do it manually (you can also " + "use an `sklearn.pipeline.Pipeline` for instance). This warning " + "will disappear in version 0.6.0.", ChangedBehaviorWarning) - chunk_mask, chunked_data = _chunk_mean_centering(X_t, chunks) + chunks = np.asanyarray(chunks, dtype=int) + chunk_mask, chunked_data = _chunk_mean_centering(X, chunks) inner_cov = np.atleast_2d(np.cov(chunked_data, rowvar=0, bias=1)) - dim = self._check_dimension(np.linalg.matrix_rank(inner_cov), X_t) + dim = self._check_dimension(np.linalg.matrix_rank(inner_cov), X) # Fisher Linear Discriminant projection - if dim < X_t.shape[1]: - total_cov = np.cov(X_t[chunk_mask], rowvar=0) + if dim < X.shape[1]: + total_cov = np.cov(X[chunk_mask], rowvar=0) tmp = np.linalg.lstsq(total_cov, inner_cov)[0] vals, vecs = np.linalg.eig(tmp) inds = np.argsort(vals)[:dim] @@ -133,9 +141,6 @@ def fit(self, X, chunks): else: self.transformer_ = _inv_sqrtm(inner_cov).T - if M_pca is not None: - self.transformer_ = np.atleast_2d(self.transformer_.dot(M_pca)) - return self @@ -155,7 +160,7 @@ class RCA_Supervised(RCA): """ def __init__(self, num_dims='deprecated', n_components=None, - pca_comps=None, num_chunks=100, chunk_size=2, + pca_comps='deprecated', num_chunks=100, chunk_size=2, preprocessor=None): """Initialize the supervised version of `RCA`. diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 18643363..7b82088e 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -18,9 +18,10 @@ HAS_SKGGM = False else: HAS_SKGGM = True -from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC, RCA, +from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC, LSML_Supervised, ITML_Supervised, SDML_Supervised, - RCA_Supervised, MMC_Supervised, SDML, ITML, LSML) + RCA_Supervised, MMC_Supervised, SDML, RCA, ITML, + LSML) # Import this specially for testing. from metric_learn.constraints import wrap_pairs from metric_learn.lmnn import python_LMNN, _sum_outer_products @@ -822,6 +823,63 @@ def test_feature_null_variance(self): csep = class_separation(rca.transform(X), self.iris_labels) self.assertLess(csep, 0.30) + def test_deprecation_pca_comps(self): + # test that a deprecation message is thrown if pca_comps is set at + # initialization + # TODO: remove in v.0.6 + X, y = make_classification(random_state=42, n_samples=100) + rca_supervised = RCA_Supervised(pca_comps=X.shape[1], num_chunks=20) + msg = ('"pca_comps" parameter is not used. ' + 'It has been deprecated in version 0.5.0 and will be' + 'removed in 0.6.0. RCA will not do PCA preprocessing anymore. If ' + 'you still want to do it, you could use ' + '`sklearn.decomposition.PCA` and an `sklearn.pipeline.Pipeline`.') + with pytest.warns(ChangedBehaviorWarning) as expected_msg: + rca_supervised.fit(X, y) + assert str(expected_msg[0].message) == msg + + rca = RCA(pca_comps=X.shape[1]) + with pytest.warns(ChangedBehaviorWarning) as expected_msg: + rca.fit(X, y) + assert str(expected_msg[0].message) == msg + + def test_changedbehaviorwarning_preprocessing(self): + # test that a ChangedBehaviorWarning is thrown when using RCA + # TODO: remove in v.0.6 + + msg = ("RCA will no longer center the data before training. If you want " + "to do some preprocessing, you should do it manually (you can also " + "use an `sklearn.pipeline.Pipeline` for instance). This warning " + "will disappear in version 0.6.0.") + + X, y = make_classification(random_state=42, n_samples=100) + rca_supervised = RCA_Supervised(num_chunks=20) + with pytest.warns(ChangedBehaviorWarning) as expected_msg: + rca_supervised.fit(X, y) + assert str(expected_msg[0].message) == msg + + rca = RCA() + with pytest.warns(ChangedBehaviorWarning) as expected_msg: + rca.fit(X, y) + assert str(expected_msg[0].message) == msg + + def test_rank_deficient_returns_warning(self): + """Checks that if the covariance matrix is not invertible, we raise a + warning message advising to use PCA""" + X, y = load_iris(return_X_y=True) + # we make the fourth column a linear combination of the two first, + # so that the covariance matrix will not be invertible: + X[:, 3] = X[:, 0] + 3 * X[:, 1] + rca = RCA() + msg = ('The inner covariance matrix is not invertible, ' + 'so the transformation matrix may contain Nan values. ' + 'You should reduce the dimensionality of your input,' + 'for instance using `sklearn.decomposition.PCA` as a ' + 'preprocessing step.') + with pytest.warns(None) as raised_warnings: + rca.fit(X, y) + assert any(str(w.message) == msg for w in raised_warnings) + @pytest.mark.parametrize('num_dims', [None, 2]) def test_deprecation_num_dims_rca(num_dims): diff --git a/test/test_base_metric.py b/test/test_base_metric.py index 1b312b35..722ff0f3 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -88,13 +88,13 @@ def test_rca(self): self.assertEqual(remove_spaces(str(metric_learn.RCA())), remove_spaces("RCA(n_components=None, " "num_dims='deprecated', " - "pca_comps=None, " + "pca_comps='deprecated', " "preprocessor=None)")) self.assertEqual(remove_spaces(str(metric_learn.RCA_Supervised())), remove_spaces( "RCA_Supervised(chunk_size=2, " "n_components=None, num_chunks=100, " - "num_dims='deprecated', pca_comps=None, " + "num_dims='deprecated', pca_comps='deprecated', " "preprocessor=None)")) def test_mlkr(self):