Skip to content

Commit f0ffdfd

Browse files
author
William de Vazelhes
committed
DEP: Add deprecation warnings for num_labels
1 parent 0c0156f commit f0ffdfd

File tree

6 files changed

+100
-11
lines changed

6 files changed

+100
-11
lines changed

metric_learn/itml.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""
1515

1616
from __future__ import print_function, absolute_import
17+
import warnings
1718
import numpy as np
1819
from six.moves import xrange
1920
from sklearn.metrics import pairwise_distances
@@ -143,7 +144,8 @@ def metric(self):
143144
class ITML_Supervised(ITML):
144145
"""Information Theoretic Metric Learning (ITML)"""
145146
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
146-
num_constraints=None, bounds=None, A0=None, verbose=False):
147+
num_labeled='deprecated', num_constraints=None, bounds=None,
148+
A0=None, verbose=False):
147149
"""Initialize the supervised version of `ITML`.
148150
149151
`ITML_Supervised` creates pairs of similar sample by taking same class
@@ -156,6 +158,10 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
156158
value for slack variables
157159
max_iter : int, optional
158160
convergence_threshold : float, optional
161+
num_labeled : Not used
162+
.. deprecated:: 0.4.0
163+
`num_labeled` was deprecated in version 0.4.0 and will
164+
be removed in 0.5.0.
159165
num_constraints: int, optional
160166
number of constraints to generate
161167
bounds : list (pos,neg) pairs, optional
@@ -164,10 +170,12 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
164170
initial regularization matrix, defaults to identity
165171
verbose : bool, optional
166172
if True, prints information while learning
173+
learning_rate : Not used
167174
"""
168175
ITML.__init__(self, gamma=gamma, max_iter=max_iter,
169176
convergence_threshold=convergence_threshold,
170177
A0=A0, verbose=verbose)
178+
self.num_labeled = num_labeled
171179
self.num_constraints = num_constraints
172180
self.bounds = bounds
173181

@@ -185,6 +193,10 @@ def fit(self, X, y, random_state=np.random):
185193
random_state : numpy.random.RandomState, optional
186194
If provided, controls random number generation.
187195
"""
196+
if self.num_labeled != 'deprecated':
197+
warnings.warn('"num_labeled" parameter is not used.'
198+
' It has been deprecated in version 0.4 and will be'
199+
'removed in 0.5', DeprecationWarning)
188200
X, y = check_X_y(X, y)
189201
num_constraints = self.num_constraints
190202
if num_constraints is None:

metric_learn/lsml.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
from __future__ import print_function, absolute_import, division
11+
import warnings
1112
import numpy as np
1213
import scipy.linalg
1314
from six.moves import xrange
@@ -133,7 +134,8 @@ def _gradient(self, metric):
133134

134135
class LSML_Supervised(LSML):
135136
def __init__(self, tol=1e-3, max_iter=1000, prior=None,
136-
num_constraints=None, weights=None, verbose=False):
137+
num_labeled='deprecated', num_constraints=None, weights=None,
138+
verbose=False):
137139
"""Initialize the supervised version of `LSML`.
138140
139141
`LSML_Supervised` creates quadruplets from labeled samples by taking two
@@ -147,6 +149,10 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None,
147149
max_iter : int, optional
148150
prior : (d x d) matrix, optional
149151
guess at a metric [default: covariance(X)]
152+
num_labeled : Not used
153+
.. deprecated:: 0.4.0
154+
`num_labeled` was deprecated in version 0.4.0 and will
155+
be removed in 0.5.0.
150156
num_constraints: int, optional
151157
number of constraints to generate
152158
weights : (m,) array of floats, optional
@@ -156,6 +162,7 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None,
156162
"""
157163
LSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior,
158164
verbose=verbose)
165+
self.num_labeled = num_labeled
159166
self.num_constraints = num_constraints
160167
self.weights = weights
161168

@@ -173,6 +180,10 @@ def fit(self, X, y, random_state=np.random):
173180
random_state : numpy.random.RandomState, optional
174181
If provided, controls random number generation.
175182
"""
183+
if self.num_labeled != 'deprecated':
184+
warnings.warn('"num_labeled" parameter is not used.'
185+
' It has been deprecated in version 0.4 and will be'
186+
'removed in 0.5', DeprecationWarning)
176187
X, y = check_X_y(X, y)
177188
num_constraints = self.num_constraints
178189
if num_constraints is None:

metric_learn/mmc.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""
1818

1919
from __future__ import print_function, absolute_import, division
20+
import warnings
2021
import numpy as np
2122
from six.moves import xrange
2223
from sklearn.metrics import pairwise_distances
@@ -384,8 +385,8 @@ def transformer(self):
384385
class MMC_Supervised(MMC):
385386
"""Mahalanobis Metric for Clustering (MMC)"""
386387
def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
387-
num_constraints=None, A0=None, diagonal=False,
388-
diagonal_c=1.0, verbose=False):
388+
num_labeled='deprecated', num_constraints=None, A0=None,
389+
diagonal=False, diagonal_c=1.0, verbose=False):
389390
"""Initialize the supervised version of `MMC`.
390391
391392
`MMC_Supervised` creates pairs of similar sample by taking same class
@@ -397,6 +398,10 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
397398
max_iter : int, optional
398399
max_proj : int, optional
399400
convergence_threshold : float, optional
401+
num_labeled : Not used
402+
.. deprecated:: 0.4.0
403+
`num_labeled` was deprecated in version 0.4.0 and will
404+
be removed in 0.5.0.
400405
num_constraints: int, optional
401406
number of constraints to generate
402407
A0 : (d x d) matrix, optional
@@ -415,6 +420,7 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
415420
convergence_threshold=convergence_threshold,
416421
A0=A0, diagonal=diagonal, diagonal_c=diagonal_c,
417422
verbose=verbose)
423+
self.num_labeled = num_labeled
418424
self.num_constraints = num_constraints
419425

420426
def fit(self, X, y, random_state=np.random):
@@ -429,6 +435,10 @@ def fit(self, X, y, random_state=np.random):
429435
random_state : numpy.random.RandomState, optional
430436
If provided, controls random number generation.
431437
"""
438+
if self.num_labeled != 'deprecated':
439+
warnings.warn('"num_labeled" parameter is not used.'
440+
' It has been deprecated in version 0.4 and will be'
441+
'removed in 0.5', DeprecationWarning)
432442
X, y = check_X_y(X, y)
433443
num_constraints = self.num_constraints
434444
if num_constraints is None:

metric_learn/sdml.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
from __future__ import absolute_import
12+
import warnings
1213
import numpy as np
1314
from scipy.sparse.csgraph import laplacian
1415
from sklearn.covariance import graph_lasso
@@ -82,7 +83,7 @@ def fit(self, X, W):
8283

8384
class SDML_Supervised(SDML):
8485
def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
85-
num_constraints=None, verbose=False):
86+
num_labeled='deprecated', num_constraints=None, verbose=False):
8687
"""Initialize the supervised version of `SDML`.
8788
8889
`SDML_Supervised` creates pairs of similar sample by taking same class
@@ -97,6 +98,10 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
9798
trade off between optimizer and sparseness (see graph_lasso)
9899
use_cov : bool, optional
99100
controls prior matrix, will use the identity if use_cov=False
101+
num_labeled : Not used
102+
.. deprecated:: 0.4.0
103+
`num_labeled` was deprecated in version 0.4.0 and will
104+
be removed in 0.5.0.
100105
num_constraints : int, optional
101106
number of constraints to generate
102107
verbose : bool, optional
@@ -105,6 +110,7 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
105110
SDML.__init__(self, balance_param=balance_param,
106111
sparsity_param=sparsity_param, use_cov=use_cov,
107112
verbose=verbose)
113+
self.num_labeled = num_labeled
108114
self.num_constraints = num_constraints
109115

110116
def fit(self, X, y, random_state=np.random):
@@ -125,6 +131,10 @@ def fit(self, X, y, random_state=np.random):
125131
self : object
126132
Returns the instance.
127133
"""
134+
if self.num_labeled != 'deprecated':
135+
warnings.warn('"num_labeled" parameter is not used.'
136+
' It has been deprecated in version 0.4 and will be'
137+
'removed in 0.5', DeprecationWarning)
128138
y = check_array(y, ensure_2d=False)
129139
num_constraints = self.num_constraints
130140
if num_constraints is None:

test/metric_learn_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,17 @@ def test_iris(self):
5555
csep = class_separation(lsml.transform(), self.iris_labels)
5656
self.assertLess(csep, 0.8) # it's pretty terrible
5757

58+
def test_deprecation(self):
59+
# test that the right deprecation message is thrown.
60+
# TODO: remove in v.0.5
61+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
62+
y = np.array([1, 0, 1, 0])
63+
lsml_supervised = LSML_Supervised(num_labeled=np.inf)
64+
msg = ('"num_labeled" parameter is not used.'
65+
' It has been deprecated in version 0.4 and will be'
66+
'removed in 0.5')
67+
assert_warns_message(DeprecationWarning, msg, lsml_supervised.fit, X, y)
68+
5869

5970
class TestITML(MetricTestCase):
6071
def test_iris(self):
@@ -64,6 +75,17 @@ def test_iris(self):
6475
csep = class_separation(itml.transform(), self.iris_labels)
6576
self.assertLess(csep, 0.2)
6677

78+
def test_deprecation(self):
79+
# test that the right deprecation message is thrown.
80+
# TODO: remove in v.0.5
81+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
82+
y = np.array([1, 0, 1, 0])
83+
itml_supervised = ITML_Supervised(num_labeled=np.inf)
84+
msg = ('"num_labeled" parameter is not used.'
85+
' It has been deprecated in version 0.4 and will be'
86+
'removed in 0.5')
87+
assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y)
88+
6789

6890
class TestLMNN(MetricTestCase):
6991
def test_iris(self):
@@ -118,6 +140,17 @@ def test_iris(self):
118140
csep = class_separation(sdml.transform(), self.iris_labels)
119141
self.assertLess(csep, 0.25)
120142

143+
def test_deprecation(self):
144+
# test that the right deprecation message is thrown.
145+
# TODO: remove in v.0.5
146+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
147+
y = np.array([1, 0, 1, 0])
148+
sdml_supervised = SDML_Supervised(num_labeled=np.inf)
149+
msg = ('"num_labeled" parameter is not used.'
150+
' It has been deprecated in version 0.4 and will be'
151+
'removed in 0.5')
152+
assert_warns_message(DeprecationWarning, msg, sdml_supervised.fit, X, y)
153+
121154

122155
class TestNCA(MetricTestCase):
123156
def test_iris(self):
@@ -343,6 +376,17 @@ def test_iris(self):
343376
csep = class_separation(mmc.transform(), self.iris_labels)
344377
self.assertLess(csep, 0.2)
345378

379+
def test_deprecation(self):
380+
# test that the right deprecation message is thrown.
381+
# TODO: remove in v.0.5
382+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
383+
y = np.array([1, 0, 1, 0])
384+
mmc_supervised = MMC_Supervised(num_labeled=np.inf)
385+
msg = ('"num_labeled" parameter is not used.'
386+
' It has been deprecated in version 0.4 and will be'
387+
'removed in 0.5')
388+
assert_warns_message(DeprecationWarning, msg, mmc_supervised.fit, X, y)
389+
346390

347391
@pytest.mark.parametrize(('algo_class', 'dataset'),
348392
[(NCA, make_classification()),

test/test_base_metric.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,27 @@ def test_itml(self):
3030
""".strip('\n'))
3131
self.assertEqual(str(metric_learn.ITML_Supervised()), """
3232
ITML_Supervised(A0=None, bounds=None, convergence_threshold=0.001, gamma=1.0,
33-
max_iter=1000, num_constraints=None, verbose=False)
33+
max_iter=1000, num_constraints=None, num_labeled='deprecated',
34+
verbose=False)
3435
""".strip('\n'))
3536

3637
def test_lsml(self):
3738
self.assertEqual(
3839
str(metric_learn.LSML()),
3940
"LSML(max_iter=1000, prior=None, tol=0.001, verbose=False)")
4041
self.assertEqual(str(metric_learn.LSML_Supervised()), """
41-
LSML_Supervised(max_iter=1000, num_constraints=None, prior=None, tol=0.001,
42-
verbose=False, weights=None)
42+
LSML_Supervised(max_iter=1000, num_constraints=None, num_labeled='deprecated',
43+
prior=None, tol=0.001, verbose=False, weights=None)
4344
""".strip('\n'))
4445

4546
def test_sdml(self):
4647
self.assertEqual(str(metric_learn.SDML()),
4748
"SDML(balance_param=0.5, sparsity_param=0.01, "
4849
"use_cov=True, verbose=False)")
4950
self.assertEqual(str(metric_learn.SDML_Supervised()), """
50-
SDML_Supervised(balance_param=0.5, num_constraints=None, sparsity_param=0.01,
51-
use_cov=True, verbose=False)
51+
SDML_Supervised(balance_param=0.5, num_constraints=None,
52+
num_labeled='deprecated', sparsity_param=0.01, use_cov=True,
53+
verbose=False)
5254
""".strip('\n'))
5355

5456
def test_rca(self):
@@ -71,7 +73,7 @@ def test_mmc(self):
7173
self.assertEqual(str(metric_learn.MMC_Supervised()), """
7274
MMC_Supervised(A0=None, convergence_threshold=1e-06, diagonal=False,
7375
diagonal_c=1.0, max_iter=100, max_proj=10000, num_constraints=None,
74-
verbose=False)
76+
num_labeled='deprecated', verbose=False)
7577
""".strip('\n'))
7678

7779
if __name__ == '__main__':

0 commit comments

Comments
 (0)