Skip to content

Commit 047191b

Browse files
author
William de Vazelhes
committed
Refactor num_dims in n_components and add deprecation
1 parent 5edad14 commit 047191b

19 files changed

+267
-116
lines changed

metric_learn/_util.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -407,11 +407,11 @@ def validate_vector(u, dtype=None):
407407
return u
408408

409409

410-
def _check_num_dims(n_features, num_dims):
411-
"""Checks that num_dims is less than n_features and deal with the None
410+
def _check_n_components(n_features, n_components):
411+
"""Checks that n_components is less than n_features and deal with the None
412412
case"""
413-
if num_dims is None:
413+
if n_components is None:
414414
return n_features
415-
if 0 < num_dims <= n_features:
415+
if 0 < n_components <= n_features:
416416
return n_features
417-
raise ValueError('Invalid num_dims, must be in [1, %d]' % n_features)
417+
raise ValueError('Invalid n_components, must be in [1, %d]' % n_features)

metric_learn/base_metric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ class MahalanobisMixin(six.with_metaclass(ABCMeta, BaseMetricLearner,
172172
173173
Attributes
174174
----------
175-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
175+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
176176
The learned linear transformation ``L``.
177177
"""
178178

@@ -232,7 +232,7 @@ def transform(self, X):
232232
233233
Returns
234234
-------
235-
X_embedded : `numpy.ndarray`, shape=(n_samples, num_dims)
235+
X_embedded : `numpy.ndarray`, shape=(n_samples, n_components)
236236
The embedded data points.
237237
"""
238238
X_checked = check_input(X, type_of_inputs='classic', estimator=self,

metric_learn/covariance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Covariance(MahalanobisMixin, TransformerMixin):
2121
2222
Attributes
2323
----------
24-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
24+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
2525
The linear transformation ``L`` deduced from the learned Mahalanobis
2626
metric (See function `transformer_from_metric`.)
2727
"""

metric_learn/itml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
145145
n_iter_ : `int`
146146
The number of iterations the solver has run.
147147
148-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
148+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
149149
The linear transformation ``L`` deduced from the learned Mahalanobis
150150
metric (See function `transformer_from_metric`.)
151151
@@ -213,7 +213,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
213213
n_iter_ : `int`
214214
The number of iterations the solver has run.
215215
216-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
216+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
217217
The linear transformation ``L`` deduced from the learned Mahalanobis
218218
metric (See function `transformer_from_metric`.)
219219
"""

metric_learn/lfda.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from sklearn.metrics import pairwise_distances
1919
from sklearn.base import TransformerMixin
2020

21-
from ._util import _check_num_dims
21+
from ._util import _check_n_components
2222
from .base_metric import MahalanobisMixin
2323

2424

@@ -29,23 +29,29 @@ class LFDA(MahalanobisMixin, TransformerMixin):
2929
3030
Attributes
3131
----------
32-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
32+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
3333
The learned linear transformation ``L``.
3434
'''
3535

36-
def __init__(self, num_dims=None, k=None, embedding_type='weighted',
37-
preprocessor=None):
36+
def __init__(self, n_components=None, num_dims='deprecated',
37+
k=None, embedding_type='weighted', preprocessor=None):
3838
'''
3939
Initialize LFDA.
4040
4141
Parameters
4242
----------
43-
num_dims : int, optional
44-
Dimensionality of reduced space (defaults to dimension of X)
43+
n_components : int or None, optional (default=None)
44+
Dimensionality of reduced space (if None, defaults to dimension of X).
45+
46+
num_dims : Not used
47+
48+
.. deprecated:: 0.5.0
49+
`num_dims` was deprecated in version 0.5.0 and will
50+
be removed in 0.6.0. Use `n_components` instead.
4551
4652
k : int, optional
4753
Number of nearest neighbors used in local scaling method.
48-
Defaults to min(7, num_dims - 1).
54+
Defaults to min(7, n_features - 1).
4955
5056
embedding_type : str, optional
5157
Type of metric in the embedding space (default: 'weighted')
@@ -59,6 +65,7 @@ def __init__(self, num_dims=None, k=None, embedding_type='weighted',
5965
'''
6066
if embedding_type not in ('weighted', 'orthonormalized', 'plain'):
6167
raise ValueError('Invalid embedding_type: %r' % embedding_type)
68+
self.n_components = n_components
6269
self.num_dims = num_dims
6370
self.embedding_type = embedding_type
6471
self.k = k
@@ -75,12 +82,17 @@ def fit(self, X, y):
7582
y : (n,) array-like
7683
Class labels, one per point of data.
7784
'''
85+
if self.num_dims != 'deprecated':
86+
warnings.warn('"num_dims" parameter is not used.'
87+
' It has been deprecated in version 0.5.0 and will be'
88+
' removed in 0.6.0. Use "n_components" instead',
89+
DeprecationWarning)
7890
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
7991
unique_classes, y = np.unique(y, return_inverse=True)
8092
n, d = X.shape
8193
num_classes = len(unique_classes)
8294

83-
dim = _check_num_dims(d, self.num_dims)
95+
dim = _check_n_components(d, self.n_components)
8496

8597
if self.k is None:
8698
k = min(7, d - 1)

metric_learn/lmnn.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@
1717
from sklearn.metrics import euclidean_distances
1818
from sklearn.base import TransformerMixin
1919

20-
from ._util import _check_num_dims
20+
from ._util import _check_n_components
2121
from .base_metric import MahalanobisMixin
2222

2323

2424
# commonality between LMNN implementations
2525
class _base_LMNN(MahalanobisMixin, TransformerMixin):
2626
def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
2727
regularization=0.5, convergence_tol=0.001, use_pca=True,
28-
verbose=False, preprocessor=None, num_dims=None):
28+
verbose=False, preprocessor=None, n_components=None,
29+
num_dims='deprecated'):
2930
"""Initialize the LMNN object.
3031
3132
Parameters
@@ -39,6 +40,15 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
3940
preprocessor : array-like, shape=(n_samples, n_features) or callable
4041
The preprocessor to call to get tuples from indices. If array-like,
4142
tuples will be formed like this: X[indices].
43+
44+
n_components : int or None, optional (default=None)
45+
Dimensionality of reduced space (if None, defaults to dimension of X).
46+
47+
num_dims : Not used
48+
49+
.. deprecated:: 0.5.0
50+
`num_dims` was deprecated in version 0.5.0 and will
51+
be removed in 0.6.0. Use `n_components` instead.
4252
"""
4353
self.k = k
4454
self.min_iter = min_iter
@@ -48,6 +58,7 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
4858
self.convergence_tol = convergence_tol
4959
self.use_pca = use_pca
5060
self.verbose = verbose
61+
self.n_components = n_components
5162
self.num_dims = num_dims
5263
super(_base_LMNN, self).__init__(preprocessor)
5364

@@ -56,21 +67,26 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
5667
class python_LMNN(_base_LMNN):
5768

5869
def fit(self, X, y):
70+
if self.num_dims != 'deprecated':
71+
warnings.warn('"num_dims" parameter is not used.'
72+
' It has been deprecated in version 0.5.0 and will be'
73+
' removed in 0.6.0. Use "n_components" instead',
74+
DeprecationWarning)
5975
k = self.k
6076
reg = self.regularization
6177
learn_rate = self.learn_rate
6278

6379
X, y = self._prepare_inputs(X, y, dtype=float,
6480
ensure_min_samples=2)
65-
num_pts, num_dims = X.shape
66-
output_dim = _check_num_dims(num_dims, self.num_dims)
81+
num_pts, d = X.shape
82+
output_dim = _check_n_components(d, self.n_components)
6783
unique_labels, label_inds = np.unique(y, return_inverse=True)
6884
if len(label_inds) != num_pts:
6985
raise ValueError('Must have one label per point.')
7086
self.labels_ = np.arange(len(unique_labels))
7187
if self.use_pca:
7288
warnings.warn('use_pca does nothing for the python_LMNN implementation')
73-
self.transformer_ = np.eye(output_dim, num_dims)
89+
self.transformer_ = np.eye(output_dim, d)
7490
required_k = np.bincount(label_inds).min()
7591
if self.k > required_k:
7692
raise ValueError('not enough class labels for specified k'
@@ -272,7 +288,7 @@ class LMNN(_base_LMNN):
272288
n_iter_ : `int`
273289
The number of iterations the solver has run.
274290
275-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
291+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
276292
The learned linear transformation ``L``.
277293
"""
278294

metric_learn/lsml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin):
139139
n_iter_ : `int`
140140
The number of iterations the solver has run.
141141
142-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
142+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
143143
The linear transformation ``L`` deduced from the learned Mahalanobis
144144
metric (See function `transformer_from_metric`.)
145145
"""
@@ -175,7 +175,7 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
175175
n_iter_ : `int`
176176
The number of iterations the solver has run.
177177
178-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
178+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
179179
The linear transformation ``L`` deduced from the learned Mahalanobis
180180
metric (See function `transformer_from_metric`.)
181181
"""

metric_learn/mlkr.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,25 @@ class MLKR(MahalanobisMixin, TransformerMixin):
3333
n_iter_ : `int`
3434
The number of iterations the solver has run.
3535
36-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
36+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
3737
The learned linear transformation ``L``.
3838
"""
3939

40-
def __init__(self, num_dims=None, A0=None, tol=None, max_iter=1000,
41-
verbose=False, preprocessor=None):
40+
def __init__(self, n_components=None, num_dims='deprecated', A0=None,
41+
tol=None, max_iter=1000, verbose=False, preprocessor=None):
4242
"""
4343
Initialize MLKR.
4444
4545
Parameters
4646
----------
47-
num_dims : int, optional
48-
Dimensionality of reduced space (defaults to dimension of X)
47+
n_components : int or None, optional (default=None)
48+
Dimensionality of reduced space (if None, defaults to dimension of X).
49+
50+
num_dims : Not used
51+
52+
.. deprecated:: 0.5.0
53+
`num_dims` was deprecated in version 0.5.0 and will
54+
be removed in 0.6.0. Use `n_components` instead.
4955
5056
A0: array-like, optional
5157
Initialization of transformation matrix. Defaults to PCA loadings.
@@ -63,6 +69,7 @@ def __init__(self, num_dims=None, A0=None, tol=None, max_iter=1000,
6369
The preprocessor to call to get tuples from indices. If array-like,
6470
tuples will be formed like this: X[indices].
6571
"""
72+
self.n_components = n_components
6673
self.num_dims = num_dims
6774
self.A0 = A0
6875
self.tol = tol
@@ -79,6 +86,11 @@ def fit(self, X, y):
7986
X : (n x d) array of samples
8087
y : (n) data labels
8188
"""
89+
if self.num_dims != 'deprecated':
90+
warnings.warn('"num_dims" parameter is not used.'
91+
' It has been deprecated in version 0.5.0 and will be'
92+
' removed in 0.6.0. Use "n_components" instead',
93+
DeprecationWarning)
8294
X, y = self._prepare_inputs(X, y, y_numeric=True,
8395
ensure_min_samples=2)
8496
n, d = X.shape
@@ -87,7 +99,7 @@ def fit(self, X, y):
8799
% (n, y.shape[0]))
88100

89101
A = self.A0
90-
m = self.num_dims
102+
m = self.n_components
91103
if m is None:
92104
m = d
93105
if A is None:

metric_learn/mmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ class MMC(_BaseMMC, _PairsClassifierMixin):
356356
n_iter_ : `int`
357357
The number of iterations the solver has run.
358358
359-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
359+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
360360
The linear transformation ``L`` deduced from the learned Mahalanobis
361361
metric (See function `transformer_from_metric`.)
362362
@@ -406,7 +406,7 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
406406
n_iter_ : `int`
407407
The number of iterations the solver has run.
408408
409-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
409+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
410410
The linear transformation ``L`` deduced from the learned Mahalanobis
411411
metric (See function `transformer_from_metric`.)
412412
"""

metric_learn/nca.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sklearn.utils.fixes import logsumexp
1515
from sklearn.base import TransformerMixin
1616

17-
from ._util import _check_num_dims
17+
from ._util import _check_n_components
1818
from .base_metric import MahalanobisMixin
1919

2020
EPS = np.finfo(float).eps
@@ -28,19 +28,24 @@ class NCA(MahalanobisMixin, TransformerMixin):
2828
n_iter_ : `int`
2929
The number of iterations the solver has run.
3030
31-
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
31+
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
3232
The learned linear transformation ``L``.
3333
"""
3434

35-
def __init__(self, num_dims=None, max_iter=100, tol=None, verbose=False,
36-
preprocessor=None):
35+
def __init__(self, n_components=None, num_dims='deprecated', max_iter=100,
36+
tol=None, verbose=False, preprocessor=None):
3737
"""Neighborhood Components Analysis
3838
3939
Parameters
4040
----------
41-
num_dims : int, optional (default=None)
42-
Embedding dimensionality. If None, will be set to ``n_features``
43-
(``d``) at fit time.
41+
n_components : int or None, optional (default=None)
42+
Dimensionality of reduced space (if None, defaults to dimension of X).
43+
44+
num_dims : Not used
45+
46+
.. deprecated:: 0.5.0
47+
`num_dims` was deprecated in version 0.5.0 and will
48+
be removed in 0.6.0. Use `n_components` instead.
4449
4550
max_iter : int, optional (default=100)
4651
Maximum number of iterations done by the optimization algorithm.
@@ -51,6 +56,7 @@ def __init__(self, num_dims=None, max_iter=100, tol=None, verbose=False,
5156
verbose : bool, optional (default=False)
5257
Whether to print progress messages or not.
5358
"""
59+
self.n_components = n_components
5460
self.num_dims = num_dims
5561
self.max_iter = max_iter
5662
self.tol = tol
@@ -62,16 +68,21 @@ def fit(self, X, y):
6268
X: data matrix, (n x d)
6369
y: scalar labels, (n)
6470
"""
71+
if self.num_dims != 'deprecated':
72+
warnings.warn('"num_dims" parameter is not used.'
73+
' It has been deprecated in version 0.5.0 and will be'
74+
' removed in 0.6.0. Use "n_components" instead',
75+
DeprecationWarning)
6576
X, labels = self._prepare_inputs(X, y, ensure_min_samples=2)
6677
n, d = X.shape
67-
num_dims = _check_num_dims(d, self.num_dims)
78+
n_components = _check_n_components(d, self.n_components)
6879

6980
# Measure the total training time
7081
train_time = time.time()
7182

7283
# Initialize A to a scaling matrix
73-
A = np.zeros((num_dims, d))
74-
np.fill_diagonal(A, 1./(np.maximum(X.max(axis=0)-X.min(axis=0), EPS)))
84+
A = np.zeros((n_components, d))
85+
np.fill_diagonal(A, 1. / (np.maximum(X.max(axis=0) - X.min(axis=0), EPS)))
7586

7687
# Run NCA
7788
mask = labels[:, np.newaxis] == labels[np.newaxis, :]
@@ -120,7 +131,7 @@ def _loss_grad_lbfgs(self, A, X, mask, sign=1.0):
120131
start_time = time.time()
121132

122133
A = A.reshape(-1, X.shape[1])
123-
X_embedded = np.dot(X, A.T) # (n_samples, num_dims)
134+
X_embedded = np.dot(X, A.T) # (n_samples, n_components)
124135
# Compute softmax distances
125136
p_ij = pairwise_distances(X_embedded, squared=True)
126137
np.fill_diagonal(p_ij, np.inf)

0 commit comments

Comments
 (0)