diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 4316802c..a0ff05f9 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -202,8 +202,8 @@ class ITML_Supervised(_BaseITML, TransformerMixin): """ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, - num_labeled='deprecated', num_constraints=None, bounds=None, - A0=None, verbose=False, preprocessor=None): + num_labeled='deprecated', num_constraints=None, + bounds='deprecated', A0=None, verbose=False, preprocessor=None): """Initialize the supervised version of `ITML`. `ITML_Supervised` creates pairs of similar sample by taking same class @@ -222,14 +222,11 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, be removed in 0.6.0. num_constraints: int, optional number of constraints to generate - bounds : `list` of two numbers - Bounds on similarity, aside slack variables, s.t. - ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` - and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of - dissimilar points ``c`` and ``d``, with ``d`` the learned distance. - If not provided at initialization, bounds_[0] and bounds_[1] will be - set to the 5th and 95th percentile of the pairwise distances among all - points in the training data `X`. + bounds : Not used + .. deprecated:: 0.5.0 + `bounds` was deprecated in version 0.5.0 and will + be removed in 0.6.0. Set `bounds` at fit time instead : + `itml_supervised.fit(X, y, bounds=...)` A0 : (d x d) matrix, optional initial regularization matrix, defaults to identity verbose : bool, optional @@ -245,7 +242,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, self.num_constraints = num_constraints self.bounds = bounds - def fit(self, X, y, random_state=np.random): + def fit(self, X, y, random_state=np.random, bounds=None): """Create constraints from labels and learn the ITML model. @@ -259,11 +256,26 @@ def fit(self, X, y, random_state=np.random): random_state : numpy.random.RandomState, optional If provided, controls random number generation. + + bounds : `list` of two numbers + Bounds on similarity, aside slack variables, s.t. + ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` + and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of + dissimilar points ``c`` and ``d``, with ``d`` the learned distance. + If not provided at initialization, bounds_[0] and bounds_[1] will be + set to the 5th and 95th percentile of the pairwise distances among all + points in the training data `X`. """ + # TODO: remove these in v0.6.0 if self.num_labeled != 'deprecated': warnings.warn('"num_labeled" parameter is not used.' ' It has been deprecated in version 0.5.0 and will be' 'removed in 0.6.0', DeprecationWarning) + if self.bounds != 'deprecated': + warnings.warn('"bounds" parameter from initialization is not used.' + ' It has been deprecated in version 0.5.0 and will be' + 'removed in 0.6.0. Use the "bounds" parameter of this ' + 'fit method instead.', DeprecationWarning) X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints if num_constraints is None: @@ -274,4 +286,4 @@ def fit(self, X, y, random_state=np.random): pos_neg = c.positive_negative_pairs(num_constraints, random_state=random_state) pairs, y = wrap_pairs(X, pos_neg) - return _BaseITML._fit(self, pairs, y, bounds=self.bounds) + return _BaseITML._fit(self, pairs, y, bounds=bounds) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index e4ce8cef..e1eace90 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -57,9 +57,10 @@ def test_iris(self): csep = class_separation(lsml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.8) # it's pretty terrible - def test_deprecation(self): - # test that the right deprecation message is thrown. - # TODO: remove in v.0.5 + def test_deprecation_num_labeled(self): + # test that a deprecation message is thrown if num_labeled is set at + # initialization + # TODO: remove in v.0.6 X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) y = np.array([1, 0, 1, 0]) lsml_supervised = LSML_Supervised(num_labeled=np.inf) @@ -77,9 +78,10 @@ def test_iris(self): csep = class_separation(itml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.2) - def test_deprecation(self): - # test that the right deprecation message is thrown. - # TODO: remove in v.0.5 + def test_deprecation_num_labeled(self): + # test that a deprecation message is thrown if num_labeled is set at + # initialization + # TODO: remove in v.0.6 X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) y = np.array([1, 0, 1, 0]) itml_supervised = ITML_Supervised(num_labeled=np.inf) @@ -88,6 +90,19 @@ def test_deprecation(self): 'removed in 0.6.0') assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y) + def test_deprecation_bounds(self): + # test that a deprecation message is thrown if bounds is set at + # initialization + # TODO: remove in v.0.6 + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + itml_supervised = ITML_Supervised(bounds=None) + msg = ('"bounds" parameter from initialization is not used.' + ' It has been deprecated in version 0.5.0 and will be' + 'removed in 0.6.0. Use the "bounds" parameter of this ' + 'fit method instead.') + assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y) + class TestLMNN(MetricTestCase): def test_iris(self): @@ -143,9 +158,10 @@ def test_iris(self): csep = class_separation(sdml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.25) - def test_deprecation(self): - # test that the right deprecation message is thrown. - # TODO: remove in v.0.5 + def test_deprecation_num_labeled(self): + # test that a deprecation message is thrown if num_labeled is set at + # initialization + # TODO: remove in v.0.6 X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) y = np.array([1, 0, 1, 0]) sdml_supervised = SDML_Supervised(num_labeled=np.inf) @@ -370,9 +386,10 @@ def test_iris(self): csep = class_separation(mmc.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.2) - def test_deprecation(self): - # test that the right deprecation message is thrown. - # TODO: remove in v.0.5 + def test_deprecation_num_labeled(self): + # test that a deprecation message is thrown if num_labeled is set at + # initialization + # TODO: remove in v.0.6 X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) y = np.array([1, 0, 1, 0]) mmc_supervised = MMC_Supervised(num_labeled=np.inf) diff --git a/test/test_base_metric.py b/test/test_base_metric.py index 09718c29..6c9a6dc5 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -36,9 +36,9 @@ def test_itml(self): preprocessor=None, verbose=False) """.strip('\n')) self.assertEqual(str(metric_learn.ITML_Supervised()), """ -ITML_Supervised(A0=None, bounds=None, convergence_threshold=0.001, gamma=1.0, - max_iter=1000, num_constraints=None, num_labeled='deprecated', - preprocessor=None, verbose=False) +ITML_Supervised(A0=None, bounds='deprecated', convergence_threshold=0.001, + gamma=1.0, max_iter=1000, num_constraints=None, + num_labeled='deprecated', preprocessor=None, verbose=False) """.strip('\n')) def test_lsml(self):