Skip to content

Commit 731b327

Browse files
authored
[MRG] Deprecate use_pca parameter of LMNN (#231)
* deprecate use_pca * fix failing test
1 parent a7ede57 commit 731b327

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

metric_learn/lmnn.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ class LMNN(MahalanobisMixin, TransformerMixin):
8787
Tolerance of the optimization procedure. If the objective value varies
8888
less than `tol`, we consider the algorithm has converged and stop it.
8989
90+
use_pca : Not used
91+
92+
.. deprecated:: 0.5.0
93+
`use_pca` was deprecated in version 0.5.0 and will
94+
be removed in 0.6.0.
95+
9096
verbose : bool, optional (default=False)
9197
Whether to print the progress of the optimization procedure.
9298
@@ -151,7 +157,7 @@ class LMNN(MahalanobisMixin, TransformerMixin):
151157

152158
def __init__(self, init=None, k=3, min_iter=50, max_iter=1000,
153159
learn_rate=1e-7, regularization=0.5, convergence_tol=0.001,
154-
use_pca=True, verbose=False, preprocessor=None,
160+
use_pca='deprecated', verbose=False, preprocessor=None,
155161
n_components=None, num_dims='deprecated', random_state=None):
156162
self.init = init
157163
self.k = k
@@ -173,6 +179,11 @@ def fit(self, X, y):
173179
' It has been deprecated in version 0.5.0 and will be'
174180
' removed in 0.6.0. Use "n_components" instead',
175181
DeprecationWarning)
182+
if self.use_pca != 'deprecated':
183+
warnings.warn('"use_pca" parameter is not used.'
184+
' It has been deprecated in version 0.5.0 and will be'
185+
' removed in 0.6.0.',
186+
DeprecationWarning)
176187
k = self.k
177188
reg = self.regularization
178189
learn_rate = self.learn_rate

test/metric_learn_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,18 @@ def test_changed_behaviour_warning(self):
353353
lmnn.fit(X, y)
354354
assert any(msg == str(wrn.message) for wrn in raised_warning)
355355

356+
def test_deprecation_use_pca(self):
357+
# test that a DeprecationWarning is thrown about use_pca, if the
358+
# default parameters are used.
359+
# TODO: remove in v.0.6
360+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
361+
y = np.array([1, 0, 1, 0])
362+
lmnn = LMNN(k=2, use_pca=True)
363+
msg = ('"use_pca" parameter is not used.'
364+
' It has been deprecated in version 0.5.0 and will be'
365+
' removed in 0.6.0.')
366+
assert_warns_message(DeprecationWarning, msg, lmnn.fit, X, y)
367+
356368

357369
@pytest.mark.parametrize('X, y, loss', [(np.array([[0], [1], [2], [3]]),
358370
[1, 1, 0, 0], 3.0),

test/test_base_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_lmnn(self):
2626
"learn_rate=1e-07, "
2727
"max_iter=1000, min_iter=50, n_components=None, "
2828
"num_dims='deprecated', preprocessor=None, random_state=None, "
29-
"regularization=0.5, use_pca=True, verbose=False)"))
29+
"regularization=0.5, use_pca='deprecated', verbose=False)"))
3030

3131
def test_nca(self):
3232
self.assertEqual(remove_spaces(str(metric_learn.NCA())),

0 commit comments

Comments
 (0)