Skip to content

Commit 297ad02

Browse files
wdevazelhesbellet
authored andcommitted
[MRG] change bounds parameter of ITML_Supervised from init to fit (#163)
* MAINT: remove variables not needed to store * Address review #159 (review) * DOC: add more precise docstring * API: put parameter in fit, deprecate it in init, and also change previous deprecation tests names * Change remaining test names
1 parent b336eba commit 297ad02

File tree

3 files changed

+56
-27
lines changed

3 files changed

+56
-27
lines changed

metric_learn/itml.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,8 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
202202
"""
203203

204204
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
205-
num_labeled='deprecated', num_constraints=None, bounds=None,
206-
A0=None, verbose=False, preprocessor=None):
205+
num_labeled='deprecated', num_constraints=None,
206+
bounds='deprecated', A0=None, verbose=False, preprocessor=None):
207207
"""Initialize the supervised version of `ITML`.
208208
209209
`ITML_Supervised` creates pairs of similar sample by taking same class
@@ -222,14 +222,11 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
222222
be removed in 0.6.0.
223223
num_constraints: int, optional
224224
number of constraints to generate
225-
bounds : `list` of two numbers
226-
Bounds on similarity, aside slack variables, s.t.
227-
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
228-
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
229-
dissimilar points ``c`` and ``d``, with ``d`` the learned distance.
230-
If not provided at initialization, bounds_[0] and bounds_[1] will be
231-
set to the 5th and 95th percentile of the pairwise distances among all
232-
points in the training data `X`.
225+
bounds : Not used
226+
.. deprecated:: 0.5.0
227+
`bounds` was deprecated in version 0.5.0 and will
228+
be removed in 0.6.0. Set `bounds` at fit time instead :
229+
`itml_supervised.fit(X, y, bounds=...)`
233230
A0 : (d x d) matrix, optional
234231
initial regularization matrix, defaults to identity
235232
verbose : bool, optional
@@ -245,7 +242,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
245242
self.num_constraints = num_constraints
246243
self.bounds = bounds
247244

248-
def fit(self, X, y, random_state=np.random):
245+
def fit(self, X, y, random_state=np.random, bounds=None):
249246
"""Create constraints from labels and learn the ITML model.
250247
251248
@@ -259,11 +256,26 @@ def fit(self, X, y, random_state=np.random):
259256
260257
random_state : numpy.random.RandomState, optional
261258
If provided, controls random number generation.
259+
260+
bounds : `list` of two numbers
261+
Bounds on similarity, aside slack variables, s.t.
262+
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
263+
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
264+
dissimilar points ``c`` and ``d``, with ``d`` the learned distance.
265+
If not provided at initialization, bounds_[0] and bounds_[1] will be
266+
set to the 5th and 95th percentile of the pairwise distances among all
267+
points in the training data `X`.
262268
"""
269+
# TODO: remove these in v0.6.0
263270
if self.num_labeled != 'deprecated':
264271
warnings.warn('"num_labeled" parameter is not used.'
265272
' It has been deprecated in version 0.5.0 and will be'
266273
'removed in 0.6.0', DeprecationWarning)
274+
if self.bounds != 'deprecated':
275+
warnings.warn('"bounds" parameter from initialization is not used.'
276+
' It has been deprecated in version 0.5.0 and will be'
277+
'removed in 0.6.0. Use the "bounds" parameter of this '
278+
'fit method instead.', DeprecationWarning)
267279
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
268280
num_constraints = self.num_constraints
269281
if num_constraints is None:
@@ -274,4 +286,4 @@ def fit(self, X, y, random_state=np.random):
274286
pos_neg = c.positive_negative_pairs(num_constraints,
275287
random_state=random_state)
276288
pairs, y = wrap_pairs(X, pos_neg)
277-
return _BaseITML._fit(self, pairs, y, bounds=self.bounds)
289+
return _BaseITML._fit(self, pairs, y, bounds=bounds)

test/metric_learn_test.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@ def test_iris(self):
5757
csep = class_separation(lsml.transform(self.iris_points), self.iris_labels)
5858
self.assertLess(csep, 0.8) # it's pretty terrible
5959

60-
def test_deprecation(self):
61-
# test that the right deprecation message is thrown.
62-
# TODO: remove in v.0.5
60+
def test_deprecation_num_labeled(self):
61+
# test that a deprecation message is thrown if num_labeled is set at
62+
# initialization
63+
# TODO: remove in v.0.6
6364
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
6465
y = np.array([1, 0, 1, 0])
6566
lsml_supervised = LSML_Supervised(num_labeled=np.inf)
@@ -77,9 +78,10 @@ def test_iris(self):
7778
csep = class_separation(itml.transform(self.iris_points), self.iris_labels)
7879
self.assertLess(csep, 0.2)
7980

80-
def test_deprecation(self):
81-
# test that the right deprecation message is thrown.
82-
# TODO: remove in v.0.5
81+
def test_deprecation_num_labeled(self):
82+
# test that a deprecation message is thrown if num_labeled is set at
83+
# initialization
84+
# TODO: remove in v.0.6
8385
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
8486
y = np.array([1, 0, 1, 0])
8587
itml_supervised = ITML_Supervised(num_labeled=np.inf)
@@ -88,6 +90,19 @@ def test_deprecation(self):
8890
'removed in 0.6.0')
8991
assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y)
9092

93+
def test_deprecation_bounds(self):
94+
# test that a deprecation message is thrown if bounds is set at
95+
# initialization
96+
# TODO: remove in v.0.6
97+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
98+
y = np.array([1, 0, 1, 0])
99+
itml_supervised = ITML_Supervised(bounds=None)
100+
msg = ('"bounds" parameter from initialization is not used.'
101+
' It has been deprecated in version 0.5.0 and will be'
102+
'removed in 0.6.0. Use the "bounds" parameter of this '
103+
'fit method instead.')
104+
assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y)
105+
91106

92107
class TestLMNN(MetricTestCase):
93108
def test_iris(self):
@@ -143,9 +158,10 @@ def test_iris(self):
143158
csep = class_separation(sdml.transform(self.iris_points), self.iris_labels)
144159
self.assertLess(csep, 0.25)
145160

146-
def test_deprecation(self):
147-
# test that the right deprecation message is thrown.
148-
# TODO: remove in v.0.5
161+
def test_deprecation_num_labeled(self):
162+
# test that a deprecation message is thrown if num_labeled is set at
163+
# initialization
164+
# TODO: remove in v.0.6
149165
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
150166
y = np.array([1, 0, 1, 0])
151167
sdml_supervised = SDML_Supervised(num_labeled=np.inf)
@@ -370,9 +386,10 @@ def test_iris(self):
370386
csep = class_separation(mmc.transform(self.iris_points), self.iris_labels)
371387
self.assertLess(csep, 0.2)
372388

373-
def test_deprecation(self):
374-
# test that the right deprecation message is thrown.
375-
# TODO: remove in v.0.5
389+
def test_deprecation_num_labeled(self):
390+
# test that a deprecation message is thrown if num_labeled is set at
391+
# initialization
392+
# TODO: remove in v.0.6
376393
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
377394
y = np.array([1, 0, 1, 0])
378395
mmc_supervised = MMC_Supervised(num_labeled=np.inf)

test/test_base_metric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def test_itml(self):
3636
preprocessor=None, verbose=False)
3737
""".strip('\n'))
3838
self.assertEqual(str(metric_learn.ITML_Supervised()), """
39-
ITML_Supervised(A0=None, bounds=None, convergence_threshold=0.001, gamma=1.0,
40-
max_iter=1000, num_constraints=None, num_labeled='deprecated',
41-
preprocessor=None, verbose=False)
39+
ITML_Supervised(A0=None, bounds='deprecated', convergence_threshold=0.001,
40+
gamma=1.0, max_iter=1000, num_constraints=None,
41+
num_labeled='deprecated', preprocessor=None, verbose=False)
4242
""".strip('\n'))
4343

4444
def test_lsml(self):

0 commit comments

Comments
 (0)