Skip to content

Commit a9979a8

Browse files
wdevazelhesbellet
authored andcommitted
[MRG] API: remove num_labeled parameter (#119)
* API: remove num_labeled parameter * DEP: Add deprecation warnings for num_labels * MAINT: put deprecation for version 0.5.0 * Revert "MAINT: put deprecation for version 0.5.0" This reverts commit 8727c44. * Revert "Merge remote-tracking branch 'origin/master' into fix/remove_num_labeled_parameter" This reverts commit 944bb3e, reversing changes made to 8727c44. * Revert "Revert "MAINT: put deprecation for version 0.5.0"" This reverts commit bc1eb32. * FIX string representation test wrongly merged * git revert d6bd0d4 * STY fix pep8 errors * STY: fix docstring indentation * FIX remove tests from NCA that are dealt with in #143 * FIX remove nca deprecation test because we remove totally learning rate in the merge #139 * FIX update version * Remove the use of random_subset
1 parent 8658e06 commit a9979a8

File tree

8 files changed

+100
-51
lines changed

8 files changed

+100
-51
lines changed

metric_learn/constraints.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,17 +89,6 @@ def chunks(self, num_chunks=100, chunk_size=2, random_state=np.random):
8989
(num_chunks, chunk_size))
9090
return chunks
9191

92-
@staticmethod
93-
def random_subset(all_labels, num_preserved=np.inf, random_state=np.random):
94-
"""
95-
the random state object to be passed must be a numpy random seed
96-
"""
97-
n = len(all_labels)
98-
num_ignored = max(0, n - num_preserved)
99-
idx = random_state.randint(n, size=num_ignored)
100-
partial_labels = np.array(all_labels, copy=True)
101-
partial_labels[idx] = -1
102-
return Constraints(partial_labels)
10392

10493
def wrap_pairs(X, constraints):
10594
a = np.array(constraints[0])
@@ -109,4 +98,4 @@ def wrap_pairs(X, constraints):
10998
constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d))))
11099
y = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))])
111100
pairs = X[constraints]
112-
return pairs, y
101+
return pairs, y

metric_learn/itml.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""
1515

1616
from __future__ import print_function, absolute_import
17+
import warnings
1718
import numpy as np
1819
from six.moves import xrange
1920
from sklearn.metrics import pairwise_distances
@@ -172,8 +173,8 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
172173
"""
173174

174175
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
175-
num_labeled=np.inf, num_constraints=None, bounds=None, A0=None,
176-
verbose=False, preprocessor=None):
176+
num_labeled='deprecated', num_constraints=None, bounds=None,
177+
A0=None, verbose=False, preprocessor=None):
177178
"""Initialize the supervised version of `ITML`.
178179
179180
`ITML_Supervised` creates pairs of similar sample by taking same class
@@ -186,10 +187,10 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
186187
value for slack variables
187188
max_iter : int, optional
188189
convergence_threshold : float, optional
189-
num_labeled : int, optional (default=np.inf)
190-
number of labeled points to keep for building pairs. Extra
191-
labeled points will be considered unlabeled, and ignored as such.
192-
Use np.inf (default) to use all labeled points.
190+
num_labeled : Not used
191+
.. deprecated:: 0.5.0
192+
`num_labeled` was deprecated in version 0.5.0 and will
193+
be removed in 0.6.0.
193194
num_constraints: int, optional
194195
number of constraints to generate
195196
bounds : list (pos,neg) pairs, optional
@@ -224,14 +225,17 @@ def fit(self, X, y, random_state=np.random):
224225
random_state : numpy.random.RandomState, optional
225226
If provided, controls random number generation.
226227
"""
228+
if self.num_labeled != 'deprecated':
229+
warnings.warn('"num_labeled" parameter is not used.'
230+
' It has been deprecated in version 0.5.0 and will be'
231+
'removed in 0.6.0', DeprecationWarning)
227232
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
228233
num_constraints = self.num_constraints
229234
if num_constraints is None:
230235
num_classes = len(np.unique(y))
231236
num_constraints = 20 * num_classes**2
232237

233-
c = Constraints.random_subset(y, self.num_labeled,
234-
random_state=random_state)
238+
c = Constraints(y)
235239
pos_neg = c.positive_negative_pairs(num_constraints,
236240
random_state=random_state)
237241
pairs, y = wrap_pairs(X, pos_neg)

metric_learn/lsml.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
from __future__ import print_function, absolute_import, division
11+
import warnings
1112
import numpy as np
1213
import scipy.linalg
1314
from six.moves import xrange
@@ -172,8 +173,9 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
172173
metric (See :meth:`transformer_from_metric`.)
173174
"""
174175

175-
def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
176-
num_constraints=None, weights=None, verbose=False,
176+
def __init__(self, tol=1e-3, max_iter=1000, prior=None,
177+
num_labeled='deprecated', num_constraints=None, weights=None,
178+
verbose=False,
177179
preprocessor=None):
178180
"""Initialize the supervised version of `LSML`.
179181
@@ -188,10 +190,10 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
188190
max_iter : int, optional
189191
prior : (d x d) matrix, optional
190192
guess at a metric [default: covariance(X)]
191-
num_labeled : int, optional (default=np.inf)
192-
number of labeled points to keep for building quadruplets. Extra
193-
labeled points will be considered unlabeled, and ignored as such.
194-
Use np.inf (default) to use all labeled points.
193+
num_labeled : Not used
194+
.. deprecated:: 0.5.0
195+
`num_labeled` was deprecated in version 0.5.0 and will
196+
be removed in 0.6.0.
195197
num_constraints: int, optional
196198
number of constraints to generate
197199
weights : (m,) array of floats, optional
@@ -222,14 +224,17 @@ def fit(self, X, y, random_state=np.random):
222224
random_state : numpy.random.RandomState, optional
223225
If provided, controls random number generation.
224226
"""
227+
if self.num_labeled != 'deprecated':
228+
warnings.warn('"num_labeled" parameter is not used.'
229+
' It has been deprecated in version 0.5.0 and will be'
230+
'removed in 0.6.0', DeprecationWarning)
225231
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
226232
num_constraints = self.num_constraints
227233
if num_constraints is None:
228234
num_classes = len(np.unique(y))
229235
num_constraints = 20 * num_classes**2
230236

231-
c = Constraints.random_subset(y, self.num_labeled,
232-
random_state=random_state)
237+
c = Constraints(y)
233238
pos_neg = c.positive_negative_pairs(num_constraints, same_length=True,
234239
random_state=random_state)
235240
return _BaseLSML._fit(self, X[np.column_stack(pos_neg)],

metric_learn/mmc.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""
1818

1919
from __future__ import print_function, absolute_import, division
20+
import warnings
2021
import numpy as np
2122
from six.moves import xrange
2223
from sklearn.base import TransformerMixin
@@ -389,8 +390,8 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
389390
"""
390391

391392
def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
392-
num_labeled=np.inf, num_constraints=None,
393-
A0=None, diagonal=False, diagonal_c=1.0, verbose=False,
393+
num_labeled='deprecated', num_constraints=None, A0=None,
394+
diagonal=False, diagonal_c=1.0, verbose=False,
394395
preprocessor=None):
395396
"""Initialize the supervised version of `MMC`.
396397
@@ -403,10 +404,10 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
403404
max_iter : int, optional
404405
max_proj : int, optional
405406
convergence_threshold : float, optional
406-
num_labeled : int, optional (default=np.inf)
407-
number of labeled points to keep for building pairs. Extra
408-
labeled points will be considered unlabeled, and ignored as such.
409-
Use np.inf (default) to use all labeled points.
407+
num_labeled : Not used
408+
.. deprecated:: 0.5.0
409+
`num_labeled` was deprecated in version 0.5.0 and will
410+
be removed in 0.6.0.
410411
num_constraints: int, optional
411412
number of constraints to generate
412413
A0 : (d x d) matrix, optional
@@ -443,14 +444,17 @@ def fit(self, X, y, random_state=np.random):
443444
random_state : numpy.random.RandomState, optional
444445
If provided, controls random number generation.
445446
"""
447+
if self.num_labeled != 'deprecated':
448+
warnings.warn('"num_labeled" parameter is not used.'
449+
' It has been deprecated in version 0.5.0 and will be'
450+
'removed in 0.6.0', DeprecationWarning)
446451
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
447452
num_constraints = self.num_constraints
448453
if num_constraints is None:
449454
num_classes = len(np.unique(y))
450455
num_constraints = 20 * num_classes**2
451456

452-
c = Constraints.random_subset(y, self.num_labeled,
453-
random_state=random_state)
457+
c = Constraints(y)
454458
pos_neg = c.positive_negative_pairs(num_constraints,
455459
random_state=random_state)
456460
pairs, y = wrap_pairs(X, pos_neg)

metric_learn/sdml.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
from __future__ import absolute_import
12+
import warnings
1213
import numpy as np
1314
from sklearn.base import TransformerMixin
1415
from sklearn.covariance import graph_lasso
@@ -113,7 +114,7 @@ class SDML_Supervised(_BaseSDML, TransformerMixin):
113114
"""
114115

115116
def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
116-
num_labeled=np.inf, num_constraints=None, verbose=False,
117+
num_labeled='deprecated', num_constraints=None, verbose=False,
117118
preprocessor=None):
118119
"""Initialize the supervised version of `SDML`.
119120
@@ -128,10 +129,10 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
128129
trade off between optimizer and sparseness (see graph_lasso)
129130
use_cov : bool, optional
130131
controls prior matrix, will use the identity if use_cov=False
131-
num_labeled : int, optional (default=np.inf)
132-
number of labeled points to keep for building pairs. Extra
133-
labeled points will be considered unlabeled, and ignored as such.
134-
Use np.inf (default) to use all labeled points.
132+
num_labeled : Not used
133+
.. deprecated:: 0.5.0
134+
`num_labeled` was deprecated in version 0.5.0 and will
135+
be removed in 0.6.0.
135136
num_constraints : int, optional
136137
number of constraints to generate
137138
verbose : bool, optional
@@ -164,14 +165,17 @@ def fit(self, X, y, random_state=np.random):
164165
self : object
165166
Returns the instance.
166167
"""
168+
if self.num_labeled != 'deprecated':
169+
warnings.warn('"num_labeled" parameter is not used.'
170+
' It has been deprecated in version 0.5.0 and will be'
171+
'removed in 0.6.0', DeprecationWarning)
167172
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
168173
num_constraints = self.num_constraints
169174
if num_constraints is None:
170175
num_classes = len(np.unique(y))
171176
num_constraints = 20 * num_classes**2
172177

173-
c = Constraints.random_subset(y, self.num_labeled,
174-
random_state=random_state)
178+
c = Constraints(y)
175179
pos_neg = c.positive_negative_pairs(num_constraints,
176180
random_state=random_state)
177181
pairs, y = wrap_pairs(X, pos_neg)

test/metric_learn_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,17 @@ def test_iris(self):
5757
csep = class_separation(lsml.transform(self.iris_points), self.iris_labels)
5858
self.assertLess(csep, 0.8) # it's pretty terrible
5959

60+
def test_deprecation(self):
61+
# test that the right deprecation message is thrown.
62+
# TODO: remove in v.0.5
63+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
64+
y = np.array([1, 0, 1, 0])
65+
lsml_supervised = LSML_Supervised(num_labeled=np.inf)
66+
msg = ('"num_labeled" parameter is not used.'
67+
' It has been deprecated in version 0.5.0 and will be'
68+
'removed in 0.6.0')
69+
assert_warns_message(DeprecationWarning, msg, lsml_supervised.fit, X, y)
70+
6071

6172
class TestITML(MetricTestCase):
6273
def test_iris(self):
@@ -66,6 +77,17 @@ def test_iris(self):
6677
csep = class_separation(itml.transform(self.iris_points), self.iris_labels)
6778
self.assertLess(csep, 0.2)
6879

80+
def test_deprecation(self):
81+
# test that the right deprecation message is thrown.
82+
# TODO: remove in v.0.5
83+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
84+
y = np.array([1, 0, 1, 0])
85+
itml_supervised = ITML_Supervised(num_labeled=np.inf)
86+
msg = ('"num_labeled" parameter is not used.'
87+
' It has been deprecated in version 0.5.0 and will be'
88+
'removed in 0.6.0')
89+
assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y)
90+
6991

7092
class TestLMNN(MetricTestCase):
7193
def test_iris(self):
@@ -121,6 +143,17 @@ def test_iris(self):
121143
csep = class_separation(sdml.transform(self.iris_points), self.iris_labels)
122144
self.assertLess(csep, 0.25)
123145

146+
def test_deprecation(self):
147+
# test that the right deprecation message is thrown.
148+
# TODO: remove in v.0.5
149+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
150+
y = np.array([1, 0, 1, 0])
151+
sdml_supervised = SDML_Supervised(num_labeled=np.inf)
152+
msg = ('"num_labeled" parameter is not used.'
153+
' It has been deprecated in version 0.5.0 and will be'
154+
'removed in 0.6.0')
155+
assert_warns_message(DeprecationWarning, msg, sdml_supervised.fit, X, y)
156+
124157

125158
class TestNCA(MetricTestCase):
126159
def test_iris(self):
@@ -335,6 +368,17 @@ def test_iris(self):
335368
csep = class_separation(mmc.transform(self.iris_points), self.iris_labels)
336369
self.assertLess(csep, 0.2)
337370

371+
def test_deprecation(self):
372+
# test that the right deprecation message is thrown.
373+
# TODO: remove in v.0.5
374+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
375+
y = np.array([1, 0, 1, 0])
376+
mmc_supervised = MMC_Supervised(num_labeled=np.inf)
377+
msg = ('"num_labeled" parameter is not used.'
378+
' It has been deprecated in version 0.5.0 and will be'
379+
'removed in 0.6.0')
380+
assert_warns_message(DeprecationWarning, msg, mmc_supervised.fit, X, y)
381+
338382

339383
@pytest.mark.parametrize(('algo_class', 'dataset'),
340384
[(NCA, make_classification()),

test/test_base_metric.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_itml(self):
3232
""".strip('\n'))
3333
self.assertEqual(str(metric_learn.ITML_Supervised()), """
3434
ITML_Supervised(A0=None, bounds=None, convergence_threshold=0.001, gamma=1.0,
35-
max_iter=1000, num_constraints=None, num_labeled=inf,
35+
max_iter=1000, num_constraints=None, num_labeled='deprecated',
3636
preprocessor=None, verbose=False)
3737
""".strip('\n'))
3838

@@ -42,7 +42,7 @@ def test_lsml(self):
4242
"LSML(max_iter=1000, preprocessor=None, prior=None, tol=0.001, "
4343
"verbose=False)")
4444
self.assertEqual(str(metric_learn.LSML_Supervised()), """
45-
LSML_Supervised(max_iter=1000, num_constraints=None, num_labeled=inf,
45+
LSML_Supervised(max_iter=1000, num_constraints=None, num_labeled='deprecated',
4646
preprocessor=None, prior=None, tol=0.001, verbose=False,
4747
weights=None)
4848
""".strip('\n'))
@@ -52,9 +52,9 @@ def test_sdml(self):
5252
"SDML(balance_param=0.5, preprocessor=None, "
5353
"sparsity_param=0.01, use_cov=True,\n verbose=False)")
5454
self.assertEqual(str(metric_learn.SDML_Supervised()), """
55-
SDML_Supervised(balance_param=0.5, num_constraints=None, num_labeled=inf,
56-
preprocessor=None, sparsity_param=0.01, use_cov=True,
57-
verbose=False)
55+
SDML_Supervised(balance_param=0.5, num_constraints=None,
56+
num_labeled='deprecated', preprocessor=None, sparsity_param=0.01,
57+
use_cov=True, verbose=False)
5858
""".strip('\n'))
5959

6060
def test_rca(self):
@@ -78,7 +78,7 @@ def test_mmc(self):
7878
self.assertEqual(str(metric_learn.MMC_Supervised()), """
7979
MMC_Supervised(A0=None, convergence_threshold=1e-06, diagonal=False,
8080
diagonal_c=1.0, max_iter=100, max_proj=10000, num_constraints=None,
81-
num_labeled=inf, preprocessor=None, verbose=False)
81+
num_labeled='deprecated', preprocessor=None, verbose=False)
8282
""".strip('\n'))
8383

8484
if __name__ == '__main__':

test/test_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ def build_data():
5858
input_data, labels = load_iris(return_X_y=True)
5959
X, y = shuffle(input_data, labels, random_state=SEED)
6060
num_constraints = 50
61-
constraints = (
62-
Constraints.random_subset(y, random_state=check_random_state(SEED)))
61+
constraints = Constraints(y)
6362
pairs = (
6463
constraints
6564
.positive_negative_pairs(num_constraints, same_length=True,

0 commit comments

Comments
 (0)