diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index 600d55c0..7c0c2038 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -87,6 +87,12 @@ class LMNN(MahalanobisMixin, TransformerMixin): Tolerance of the optimization procedure. If the objective value varies less than `tol`, we consider the algorithm has converged and stop it. + use_pca : Not used + + .. deprecated:: 0.5.0 + `use_pca` was deprecated in version 0.5.0 and will + be removed in 0.6.0. + verbose : bool, optional (default=False) Whether to print the progress of the optimization procedure. @@ -151,7 +157,7 @@ class LMNN(MahalanobisMixin, TransformerMixin): def __init__(self, init=None, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, - use_pca=True, verbose=False, preprocessor=None, + use_pca='deprecated', verbose=False, preprocessor=None, n_components=None, num_dims='deprecated', random_state=None): self.init = init self.k = k @@ -173,6 +179,11 @@ def fit(self, X, y): ' It has been deprecated in version 0.5.0 and will be' ' removed in 0.6.0. Use "n_components" instead', DeprecationWarning) + if self.use_pca != 'deprecated': + warnings.warn('"use_pca" parameter is not used.' + ' It has been deprecated in version 0.5.0 and will be' + ' removed in 0.6.0.', + DeprecationWarning) k = self.k reg = self.regularization learn_rate = self.learn_rate diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index c49c9ef5..e05329ab 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -291,6 +291,18 @@ def test_changed_behaviour_warning(self): lmnn.fit(X, y) assert any(msg == str(wrn.message) for wrn in raised_warning) + def test_deprecation_use_pca(self): + # test that a DeprecationWarning is thrown about use_pca, if the + # default parameters are used. + # TODO: remove in v.0.6 + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + lmnn = LMNN(k=2, use_pca=True) + msg = ('"use_pca" parameter is not used.' + ' It has been deprecated in version 0.5.0 and will be' + ' removed in 0.6.0.') + assert_warns_message(DeprecationWarning, msg, lmnn.fit, X, y) + @pytest.mark.parametrize('X, y, loss', [(np.array([[0], [1], [2], [3]]), [1, 1, 0, 0], 3.0), diff --git a/test/test_base_metric.py b/test/test_base_metric.py index 313948ec..d3e7802c 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -26,7 +26,7 @@ def test_lmnn(self): "learn_rate=1e-07, " "max_iter=1000, min_iter=50, n_components=None, " "num_dims='deprecated', preprocessor=None, random_state=None, " - "regularization=0.5, use_pca=True, verbose=False)")) + "regularization=0.5, use_pca='deprecated', verbose=False)")) def test_nca(self): self.assertEqual(remove_spaces(str(metric_learn.NCA())),