Skip to content

Commit 3899653

Browse files
wdevazelhes authored and bellet committed
[MRG] Uniformize num_dims to n_components and add it for LMNN (#193)
* Uniformize num_dims and add it for LMNN
* MAINT: fix imports
* Fix: fix test_num_dims
* MAINT: Address #193 (review)
* Refactor num_dims in n_components and add deprecation
* FIX make some tests work
* FIX Make tests work (fix deprecation messages and fix RCA example)
* Remove unused import
* Revert "Remove unused import" (this reverts commit 81c9a8d)
* Fix import
* FIX fix some tests
* Allow more general sign switching in test_lfda
1 parent efba316 commit 3899653

19 files changed

+327
-121
lines changed

metric_learn/_util.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,3 +411,13 @@ def validate_vector(u, dtype=None):
411411
if u.ndim > 1:
412412
raise ValueError("Input vector should be 1-D.")
413413
return u
414+
415+
416+
def _check_n_components(n_features, n_components):
417+
"""Checks that n_components is less than n_features and deal with the None
418+
case"""
419+
if n_components is None:
420+
return n_features
421+
if 0 < n_components <= n_features:
422+
return n_components
423+
raise ValueError('Invalid n_components, must be in [1, %d]' % n_features)

metric_learn/base_metric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ class MahalanobisMixin(six.with_metaclass(ABCMeta, BaseMetricLearner,
172172
173173
Attributes
174174
----------
175-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
175+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
176176
The learned linear transformation ``L``.
177177
"""
178178

@@ -232,7 +232,7 @@ def transform(self, X):
232232
233233
Returns
234234
-------
235-
X_embedded : `numpy.ndarray`, shape=(n_samples, num_dims)
235+
X_embedded : `numpy.ndarray`, shape=(n_samples, n_components)
236236
The embedded data points.
237237
"""
238238
X_checked = check_input(X, type_of_inputs='classic', estimator=self,
@@ -288,7 +288,7 @@ def get_mahalanobis_matrix(self):
288288
289289
Returns
290290
-------
291-
M : `numpy.ndarray`, shape=(n_components, n_features)
291+
M : `numpy.ndarray`, shape=(n_features, n_features)
292292
The copy of the learned Mahalanobis matrix.
293293
"""
294294
return self.transformer_.T.dot(self.transformer_)

metric_learn/covariance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Covariance(MahalanobisMixin, TransformerMixin):
2222
2323
Attributes
2424
----------
25-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
25+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
2626
The linear transformation ``L`` deduced from the learned Mahalanobis
2727
metric (See function `transformer_from_metric`.)
2828
"""

metric_learn/itml.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
150150
n_iter_ : `int`
151151
The number of iterations the solver has run.
152152
153-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
153+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
154154
The linear transformation ``L`` deduced from the learned Mahalanobis
155155
metric (See function `transformer_from_metric`.)
156156
@@ -218,7 +218,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
218218
n_iter_ : `int`
219219
The number of iterations the solver has run.
220220
221-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
221+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
222222
The linear transformation ``L`` deduced from the learned Mahalanobis
223223
metric (See function `transformer_from_metric`.)
224224
"""
@@ -292,11 +292,11 @@ def fit(self, X, y, random_state=np.random, bounds=None):
292292
if self.num_labeled != 'deprecated':
293293
warnings.warn('"num_labeled" parameter is not used.'
294294
' It has been deprecated in version 0.5.0 and will be'
295-
'removed in 0.6.0', DeprecationWarning)
295+
' removed in 0.6.0', DeprecationWarning)
296296
if self.bounds != 'deprecated':
297297
warnings.warn('"bounds" parameter from initialization is not used.'
298298
' It has been deprecated in version 0.5.0 and will be'
299-
'removed in 0.6.0. Use the "bounds" parameter of this '
299+
' removed in 0.6.0. Use the "bounds" parameter of this '
300300
'fit method instead.', DeprecationWarning)
301301
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
302302
num_constraints = self.num_constraints

metric_learn/lfda.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from six.moves import xrange
1717
from sklearn.metrics import pairwise_distances
1818
from sklearn.base import TransformerMixin
19+
20+
from ._util import _check_n_components
1921
from .base_metric import MahalanobisMixin
2022

2123

@@ -26,23 +28,29 @@ class LFDA(MahalanobisMixin, TransformerMixin):
2628
2729
Attributes
2830
----------
29-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
31+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
3032
The learned linear transformation ``L``.
3133
'''
3234

33-
def __init__(self, num_dims=None, k=None, embedding_type='weighted',
34-
preprocessor=None):
35+
def __init__(self, n_components=None, num_dims='deprecated',
36+
k=None, embedding_type='weighted', preprocessor=None):
3537
'''
3638
Initialize LFDA.
3739
3840
Parameters
3941
----------
40-
num_dims : int, optional
41-
Dimensionality of reduced space (defaults to dimension of X)
42+
n_components : int or None, optional (default=None)
43+
Dimensionality of reduced space (if None, defaults to dimension of X).
44+
45+
num_dims : Not used
46+
47+
.. deprecated:: 0.5.0
48+
`num_dims` was deprecated in version 0.5.0 and will
49+
be removed in 0.6.0. Use `n_components` instead.
4250
4351
k : int, optional
4452
Number of nearest neighbors used in local scaling method.
45-
Defaults to min(7, num_dims - 1).
53+
Defaults to min(7, n_components - 1).
4654
4755
embedding_type : str, optional
4856
Type of metric in the embedding space (default: 'weighted')
@@ -56,6 +64,7 @@ def __init__(self, num_dims=None, k=None, embedding_type='weighted',
5664
'''
5765
if embedding_type not in ('weighted', 'orthonormalized', 'plain'):
5866
raise ValueError('Invalid embedding_type: %r' % embedding_type)
67+
self.n_components = n_components
5968
self.num_dims = num_dims
6069
self.embedding_type = embedding_type
6170
self.k = k
@@ -72,17 +81,17 @@ def fit(self, X, y):
7281
y : (n,) array-like
7382
Class labels, one per point of data.
7483
'''
84+
if self.num_dims != 'deprecated':
85+
warnings.warn('"num_dims" parameter is not used.'
86+
' It has been deprecated in version 0.5.0 and will be'
87+
' removed in 0.6.0. Use "n_components" instead',
88+
DeprecationWarning)
7589
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
7690
unique_classes, y = np.unique(y, return_inverse=True)
7791
n, d = X.shape
7892
num_classes = len(unique_classes)
7993

80-
if self.num_dims is None:
81-
dim = d
82-
else:
83-
if not 0 < self.num_dims <= d:
84-
raise ValueError('Invalid num_dims, must be in [1,%d]' % d)
85-
dim = self.num_dims
94+
dim = _check_n_components(d, self.n_components)
8695

8796
if self.k is None:
8897
k = min(7, d - 1)

metric_learn/lmnn.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@
1919
from six.moves import xrange
2020
from sklearn.metrics import euclidean_distances
2121
from sklearn.base import TransformerMixin
22+
23+
from ._util import _check_n_components
2224
from .base_metric import MahalanobisMixin
2325

2426

2527
# commonality between LMNN implementations
2628
class _base_LMNN(MahalanobisMixin, TransformerMixin):
2729
def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
2830
regularization=0.5, convergence_tol=0.001, use_pca=True,
29-
verbose=False, preprocessor=None):
31+
verbose=False, preprocessor=None, n_components=None,
32+
num_dims='deprecated'):
3033
"""Initialize the LMNN object.
3134
3235
Parameters
@@ -40,6 +43,15 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
4043
preprocessor : array-like, shape=(n_samples, n_features) or callable
4144
The preprocessor to call to get tuples from indices. If array-like,
4245
tuples will be formed like this: X[indices].
46+
47+
n_components : int or None, optional (default=None)
48+
Dimensionality of reduced space (if None, defaults to dimension of X).
49+
50+
num_dims : Not used
51+
52+
.. deprecated:: 0.5.0
53+
`num_dims` was deprecated in version 0.5.0 and will
54+
be removed in 0.6.0. Use `n_components` instead.
4355
"""
4456
self.k = k
4557
self.min_iter = min_iter
@@ -49,27 +61,35 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
4961
self.convergence_tol = convergence_tol
5062
self.use_pca = use_pca
5163
self.verbose = verbose
64+
self.n_components = n_components
65+
self.num_dims = num_dims
5266
super(_base_LMNN, self).__init__(preprocessor)
5367

5468

5569
# slower Python version
5670
class python_LMNN(_base_LMNN):
5771

5872
def fit(self, X, y):
73+
if self.num_dims != 'deprecated':
74+
warnings.warn('"num_dims" parameter is not used.'
75+
' It has been deprecated in version 0.5.0 and will be'
76+
' removed in 0.6.0. Use "n_components" instead',
77+
DeprecationWarning)
5978
k = self.k
6079
reg = self.regularization
6180
learn_rate = self.learn_rate
6281

6382
X, y = self._prepare_inputs(X, y, dtype=float,
6483
ensure_min_samples=2)
65-
num_pts, num_dims = X.shape
84+
num_pts, d = X.shape
85+
output_dim = _check_n_components(d, self.n_components)
6686
unique_labels, label_inds = np.unique(y, return_inverse=True)
6787
if len(label_inds) != num_pts:
6888
raise ValueError('Must have one label per point.')
6989
self.labels_ = np.arange(len(unique_labels))
7090
if self.use_pca:
7191
warnings.warn('use_pca does nothing for the python_LMNN implementation')
72-
self.transformer_ = np.eye(num_dims)
92+
self.transformer_ = np.eye(output_dim, d)
7393
required_k = np.bincount(label_inds).min()
7494
if self.k > required_k:
7595
raise ValueError('not enough class labels for specified k'
@@ -272,7 +292,7 @@ class LMNN(_base_LMNN):
272292
n_iter_ : `int`
273293
The number of iterations the solver has run.
274294
275-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
295+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
276296
The learned linear transformation ``L``.
277297
"""
278298

metric_learn/lsml.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin):
146146
n_iter_ : `int`
147147
The number of iterations the solver has run.
148148
149-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
149+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
150150
The linear transformation ``L`` deduced from the learned Mahalanobis
151151
metric (See function `transformer_from_metric`.)
152152
"""
@@ -182,7 +182,7 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
182182
n_iter_ : `int`
183183
The number of iterations the solver has run.
184184
185-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
185+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
186186
The linear transformation ``L`` deduced from the learned Mahalanobis
187187
metric (See function `transformer_from_metric`.)
188188
"""
@@ -241,7 +241,7 @@ def fit(self, X, y, random_state=np.random):
241241
if self.num_labeled != 'deprecated':
242242
warnings.warn('"num_labeled" parameter is not used.'
243243
' It has been deprecated in version 0.5.0 and will be'
244-
'removed in 0.6.0', DeprecationWarning)
244+
' removed in 0.6.0', DeprecationWarning)
245245
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
246246
num_constraints = self.num_constraints
247247
if num_constraints is None:

metric_learn/mlkr.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424

2525
from sklearn.metrics import pairwise_distances
26+
27+
from metric_learn._util import _check_n_components
2628
from .base_metric import MahalanobisMixin
2729

2830
EPS = np.finfo(float).eps
@@ -36,19 +38,25 @@ class MLKR(MahalanobisMixin, TransformerMixin):
3638
n_iter_ : `int`
3739
The number of iterations the solver has run.
3840
39-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
41+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
4042
The learned linear transformation ``L``.
4143
"""
4244

43-
def __init__(self, num_dims=None, A0=None, tol=None, max_iter=1000,
44-
verbose=False, preprocessor=None):
45+
def __init__(self, n_components=None, num_dims='deprecated', A0=None,
46+
tol=None, max_iter=1000, verbose=False, preprocessor=None):
4547
"""
4648
Initialize MLKR.
4749
4850
Parameters
4951
----------
50-
num_dims : int, optional
51-
Dimensionality of reduced space (defaults to dimension of X)
52+
n_components : int or None, optional (default=None)
53+
Dimensionality of reduced space (if None, defaults to dimension of X).
54+
55+
num_dims : Not used
56+
57+
.. deprecated:: 0.5.0
58+
`num_dims` was deprecated in version 0.5.0 and will
59+
be removed in 0.6.0. Use `n_components` instead.
5260
5361
A0: array-like, optional
5462
Initialization of transformation matrix. Defaults to PCA loadings.
@@ -66,6 +74,7 @@ def __init__(self, num_dims=None, A0=None, tol=None, max_iter=1000,
6674
The preprocessor to call to get tuples from indices. If array-like,
6775
tuples will be formed like this: X[indices].
6876
"""
77+
self.n_components = n_components
6978
self.num_dims = num_dims
7079
self.A0 = A0
7180
self.tol = tol
@@ -82,6 +91,11 @@ def fit(self, X, y):
8291
X : (n x d) array of samples
8392
y : (n) data labels
8493
"""
94+
if self.num_dims != 'deprecated':
95+
warnings.warn('"num_dims" parameter is not used.'
96+
' It has been deprecated in version 0.5.0 and will be'
97+
' removed in 0.6.0. Use "n_components" instead',
98+
DeprecationWarning)
8599
X, y = self._prepare_inputs(X, y, y_numeric=True,
86100
ensure_min_samples=2)
87101
n, d = X.shape
@@ -90,7 +104,8 @@ def fit(self, X, y):
90104
% (n, y.shape[0]))
91105

92106
A = self.A0
93-
m = self.num_dims
107+
m = _check_n_components(d, self.n_components)
108+
m = self.n_components
94109
if m is None:
95110
m = d
96111
if A is None:

metric_learn/mmc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ class MMC(_BaseMMC, _PairsClassifierMixin):
356356
n_iter_ : `int`
357357
The number of iterations the solver has run.
358358
359-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
359+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
360360
The linear transformation ``L`` deduced from the learned Mahalanobis
361361
metric (See function `transformer_from_metric`.)
362362
@@ -406,7 +406,7 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
406406
n_iter_ : `int`
407407
The number of iterations the solver has run.
408408
409-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
409+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
410410
The linear transformation ``L`` deduced from the learned Mahalanobis
411411
metric (See function `transformer_from_metric`.)
412412
"""
@@ -469,7 +469,7 @@ def fit(self, X, y, random_state=np.random):
469469
if self.num_labeled != 'deprecated':
470470
warnings.warn('"num_labeled" parameter is not used.'
471471
' It has been deprecated in version 0.5.0 and will be'
472-
'removed in 0.6.0', DeprecationWarning)
472+
' removed in 0.6.0', DeprecationWarning)
473473
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
474474
num_constraints = self.num_constraints
475475
if num_constraints is None:

0 commit comments

Comments
 (0)