Skip to content

Commit 6273578

Browse files
author
mvargas33
committed
Add all warnings regarding n_constrains
1 parent fa74609 commit 6273578

File tree

7 files changed

+67
-15
lines changed

7 files changed

+67
-15
lines changed

metric_learn/constraints.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import warnings
77
from sklearn.utils import check_random_state
88
from sklearn.neighbors import NearestNeighbors
9-
import warnings
9+
1010

1111
__all__ = ['Constraints']
1212

@@ -56,7 +56,7 @@ def positive_negative_pairs(self, n_constraints, same_length=False,
5656
random_state : int or numpy.RandomState or None, optional (default=None)
5757
A pseudo random number generator object or a seed for it if int.
5858
59-
num_constraints : Renamed to n_constrains. Will be deprecated 0.7.0
59+
num_constraints : Renamed to n_constraints. Will be deprecated in 0.7.0
6060
6161
Returns
6262
-------
@@ -72,12 +72,14 @@ def positive_negative_pairs(self, n_constraints, same_length=False,
7272
d : array-like, shape=(n_constraints,)
7373
1D array of indicators for the right elements of negative pairs.
7474
"""
75-
if self.num_constraints != 'deprecated':
76-
warnings.warn('"num_constraints" parameter has been renamed to '
77-
'"n_constraints". It has been deprecated in'
75+
if num_constraints != 'deprecated':
76+
warnings.warn('"num_constraints" parameter has been renamed to'
77+
' "n_constraints". It has been deprecated in'
7878
' version 0.6.3 and will be removed in 0.7.0'
7979
'', FutureWarning)
8080
self.n_constraints = num_constraints
81+
else:
82+
self.n_constraints = n_constraints
8183
random_state = check_random_state(random_state)
8284
a, b = self._pairs(n_constraints, same_label=True,
8385
random_state=random_state)

metric_learn/itml.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
1010
from .constraints import Constraints, wrap_pairs
1111
from ._util import components_from_metric, _initialize_metric_mahalanobis
12+
import warnings
1213

1314

1415
class _BaseITML(MahalanobisMixin):
@@ -302,6 +303,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
302303
case, `random_state` is also used to randomly sample constraints from
303304
labels.
304305
306+
num_constraints : Renamed to n_constraints. Will be deprecated in 0.7.0
305307
306308
Attributes
307309
----------
@@ -340,12 +342,22 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
340342

341343
def __init__(self, gamma=1.0, max_iter=1000, tol=1e-3,
342344
n_constraints=None, prior='identity',
343-
verbose=False, preprocessor=None, random_state=None):
345+
verbose=False, preprocessor=None, random_state=None,
346+
num_constraints='deprecated'):
344347
_BaseITML.__init__(self, gamma=gamma, max_iter=max_iter,
345348
tol=tol,
346349
prior=prior, verbose=verbose,
347350
preprocessor=preprocessor, random_state=random_state)
348-
self.n_constraints = n_constraints
351+
if num_constraints != 'deprecated':
352+
warnings.warn('"num_constraints" parameter has been renamed to'
353+
' "n_constraints". It has been deprecated in'
354+
' version 0.6.3 and will be removed in 0.7.0'
355+
'', FutureWarning)
356+
self.n_constraints = num_constraints
357+
else:
358+
self.n_constraints = n_constraints
359+
# Avoid test get_params from failing (all params passed sholud be set)
360+
self.num_constraints = 'deprecated'
349361

350362
def fit(self, X, y, bounds=None):
351363
"""Create constraints from labels and learn the ITML model.

metric_learn/lsml.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin
1010
from .constraints import Constraints
1111
from ._util import components_from_metric, _initialize_metric_mahalanobis
12+
import warnings
1213

1314

1415
class _BaseLSML(MahalanobisMixin):
@@ -282,6 +283,8 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
282283
prior. In any case, `random_state` is also used to randomly sample
283284
constraints from labels.
284285
286+
num_constraints : Renamed to n_constraints. Will be deprecated in 0.7.0
287+
285288
Examples
286289
--------
287290
>>> from metric_learn import LSML_Supervised
@@ -304,11 +307,21 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
304307

305308
def __init__(self, tol=1e-3, max_iter=1000, prior='identity',
306309
n_constraints=None, weights=None,
307-
verbose=False, preprocessor=None, random_state=None):
310+
verbose=False, preprocessor=None, random_state=None,
311+
num_constraints='deprecated'):
308312
_BaseLSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior,
309313
verbose=verbose, preprocessor=preprocessor,
310314
random_state=random_state)
311-
self.n_constraints = n_constraints
315+
if num_constraints != 'deprecated':
316+
warnings.warn('"num_constraints" parameter has been renamed to'
317+
' "n_constraints". It has been deprecated in'
318+
' version 0.6.3 and will be removed in 0.7.0'
319+
'', FutureWarning)
320+
self.n_constraints = num_constraints
321+
else:
322+
self.n_constraints = n_constraints
323+
# Avoid test get_params from failing (all params passed sholud be set)
324+
self.num_constraints = 'deprecated'
312325
self.weights = weights
313326

314327
def fit(self, X, y):

metric_learn/mmc.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
77
from .constraints import Constraints, wrap_pairs
88
from ._util import components_from_metric, _initialize_metric_mahalanobis
9+
import warnings
910

1011

1112
class _BaseMMC(MahalanobisMixin):
@@ -518,6 +519,8 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
518519
Mahalanobis matrix. In any case, `random_state` is also used to
519520
randomly sample constraints from labels.
520521
522+
num_constraints : Renamed to n_constraints. Will be deprecated in 0.7.0
523+
521524
Examples
522525
--------
523526
>>> from metric_learn import MMC_Supervised
@@ -541,13 +544,23 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
541544
def __init__(self, max_iter=100, max_proj=10000, tol=1e-6,
542545
n_constraints=None, init='identity',
543546
diagonal=False, diagonal_c=1.0, verbose=False,
544-
preprocessor=None, random_state=None):
547+
preprocessor=None, random_state=None,
548+
num_constraints='deprecated'):
545549
_BaseMMC.__init__(self, max_iter=max_iter, max_proj=max_proj,
546550
tol=tol,
547551
init=init, diagonal=diagonal,
548552
diagonal_c=diagonal_c, verbose=verbose,
549553
preprocessor=preprocessor, random_state=random_state)
550-
self.n_constraints = n_constraints
554+
if num_constraints != 'deprecated':
555+
warnings.warn('"num_constraints" parameter has been renamed to'
556+
' "n_constraints". It has been deprecated in'
557+
' version 0.6.3 and will be removed in 0.7.0'
558+
'', FutureWarning)
559+
self.n_constraints = num_constraints
560+
else:
561+
self.n_constraints = n_constraints
562+
# Avoid test get_params from failing (all params passed sholud be set)
563+
self.num_constraints = 'deprecated'
551564

552565
def fit(self, X, y):
553566
"""Create constraints from labels and learn the MMC model.

metric_learn/rca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def __init__(self, n_components=None, n_chunks=100, chunk_size=2,
182182

183183
def fit(self, X, y):
184184
"""Create constraints from labels and learn the RCA model.
185-
Needs n_constraints specified in constructor.
185+
Needs n_constraints specified in constructor. (Not true?)
186186
187187
Parameters
188188
----------

metric_learn/sdml.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,8 @@ class SDML_Supervised(_BaseSDML, TransformerMixin):
279279
prior. In any case, `random_state` is also used to randomly sample
280280
constraints from labels.
281281
282+
num_constraints : Renamed to n_constraints. Will be deprecated in 0.7.0
283+
282284
Attributes
283285
----------
284286
components_ : `numpy.ndarray`, shape=(n_features, n_features)
@@ -294,12 +296,21 @@ class SDML_Supervised(_BaseSDML, TransformerMixin):
294296

295297
def __init__(self, balance_param=0.5, sparsity_param=0.01, prior='identity',
296298
n_constraints=None, verbose=False, preprocessor=None,
297-
random_state=None):
299+
random_state=None, num_constraints='deprecated'):
298300
_BaseSDML.__init__(self, balance_param=balance_param,
299301
sparsity_param=sparsity_param, prior=prior,
300302
verbose=verbose,
301303
preprocessor=preprocessor, random_state=random_state)
302-
self.n_constraints = n_constraints
304+
if num_constraints != 'deprecated':
305+
warnings.warn('"num_constraints" parameter has been renamed to'
306+
' "n_constraints". It has been deprecated in'
307+
' version 0.6.3 and will be removed in 0.7.0'
308+
'', FutureWarning)
309+
self.n_constraints = num_constraints
310+
else:
311+
self.n_constraints = n_constraints
312+
# Avoid test get_params from failing (all params passed sholud be set)
313+
self.num_constraints = 'deprecated'
303314

304315
def fit(self, X, y):
305316
"""Create constraints from labels and learn the SDML model.

test/test_base_metric.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def test_lmnn(self):
4444
nndef_kwargs = {'convergence_tol': 0.01, 'n_neighbors': 6}
4545
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
4646
self.assertEqual(
47-
remove_spaces(str(metric_learn.LMNN(convergence_tol=0.01, n_neighbors=6))),
47+
remove_spaces(str(metric_learn.LMNN(convergence_tol=0.01,
48+
n_neighbors=6))),
4849
remove_spaces(f"LMNN({merged_kwargs})"))
4950

5051
def test_nca(self):

0 commit comments

Comments
 (0)