Skip to content

Commit 17216a7

Browse files
authored
Rename variables, proposed by issue #257 (#324)
* Rename number_constrains to n_constraints * Renamed num_chunks to n_chunks * LMNN k parameter renamed to n_neighbors * Replaced all 'convergence_threshold' with 'tol' * Fix tests * Fixed more test regarding rename of variable * Warnings for n_constrains * Add all warnings regarding n_constrains * Deprecation warnings for n_chunks * Add deprecation warn to n_neighbors * Add convergence_threshold warnings
1 parent 8520418 commit 17216a7

20 files changed

+305
-186
lines changed

bench/benchmarks/iris.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55

66
CLASSES = {
77
'Covariance': metric_learn.Covariance(),
8-
'ITML_Supervised': metric_learn.ITML_Supervised(num_constraints=200),
8+
'ITML_Supervised': metric_learn.ITML_Supervised(n_constraints=200),
99
'LFDA': metric_learn.LFDA(k=2, dim=2),
10-
'LMNN': metric_learn.LMNN(k=5, learn_rate=1e-6, verbose=False),
11-
'LSML_Supervised': metric_learn.LSML_Supervised(num_constraints=200),
10+
'LMNN': metric_learn.LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False),
11+
'LSML_Supervised': metric_learn.LSML_Supervised(n_constraints=200),
1212
'MLKR': metric_learn.MLKR(),
1313
'NCA': metric_learn.NCA(max_iter=700, n_components=2),
14-
'RCA_Supervised': metric_learn.RCA_Supervised(dim=2, num_chunks=30,
14+
'RCA_Supervised': metric_learn.RCA_Supervised(dim=2, n_chunks=30,
1515
chunk_size=2),
16-
'SDML_Supervised': metric_learn.SDML_Supervised(num_constraints=1500)
16+
'SDML_Supervised': metric_learn.SDML_Supervised(n_constraints=1500)
1717
}
1818

1919

doc/supervised.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ indicates :math:`\mathbf{x}_{i}, \mathbf{x}_{j}` belong to different classes,
164164
X = iris_data['data']
165165
Y = iris_data['target']
166166

167-
lmnn = LMNN(k=5, learn_rate=1e-6)
167+
lmnn = LMNN(n_neighbors=5, learn_rate=1e-6)
168168
lmnn.fit(X, Y, verbose=False)
169169

170170
.. rubric:: References
@@ -407,8 +407,8 @@ are similar (+1) or dissimilar (-1)), are sampled with the function
407407
(of label +1), this method will look at all the samples from the same label and
408408
sample randomly a pair among them. To sample negative pairs (of label -1), this
409409
method will look at all the samples from a different class and sample randomly
410-
a pair among them. The method will try to build `num_constraints` positive
411-
pairs and `num_constraints` negative pairs, but sometimes it cannot find enough
410+
a pair among them. The method will try to build `n_constraints` positive
411+
pairs and `n_constraints` negative pairs, but sometimes it cannot find enough
412412
of one of those, so forcing `same_length=True` will return both times the
413413
minimum of the two lenghts.
414414

@@ -430,5 +430,5 @@ last points should be less similar than the two first points).
430430
X = iris_data['data']
431431
Y = iris_data['target']
432432

433-
mmc = MMC_Supervised(num_constraints=200)
433+
mmc = MMC_Supervised(n_constraints=200)
434434
mmc.fit(X, Y)

doc/weakly_supervised.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ are respected.
137137
>>> from metric_learn import MMC
138138
>>> mmc = MMC(random_state=42)
139139
>>> mmc.fit(tuples, y)
140-
MMC(A0='deprecated', convergence_threshold=0.001, diagonal=False,
140+
MMC(A0='deprecated', tol=0.001, diagonal=False,
141141
diagonal_c=1.0, init='auto', max_iter=100, max_proj=10000,
142142
preprocessor=None, random_state=42, verbose=False)
143143

@@ -263,7 +263,7 @@ tuples).
263263
>>> y_pairs = np.array([1, -1])
264264
>>> mmc = MMC(random_state=42)
265265
>>> mmc.fit(pairs, y_pairs)
266-
MMC(convergence_threshold=0.001, diagonal=False,
266+
MMC(tol=0.001, diagonal=False,
267267
diagonal_c=1.0, init='auto', max_iter=100, max_proj=10000, preprocessor=None,
268268
random_state=42, verbose=False)
269269

examples/plot_metric_learning_examples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def plot_tsne(X, y, colormap=plt.cm.Paired):
143143
#
144144

145145
# setting up LMNN
146-
lmnn = metric_learn.LMNN(k=5, learn_rate=1e-6)
146+
lmnn = metric_learn.LMNN(n_neighbors=5, learn_rate=1e-6)
147147

148148
# fit the data!
149149
lmnn.fit(X, y)
@@ -314,7 +314,7 @@ def plot_tsne(X, y, colormap=plt.cm.Paired):
314314
# - See more in the documentation of the class :py:class:`RCA
315315
# <metric_learn.RCA>`
316316

317-
rca = metric_learn.RCA_Supervised(num_chunks=30, chunk_size=2)
317+
rca = metric_learn.RCA_Supervised(n_chunks=30, chunk_size=2)
318318
X_rca = rca.fit_transform(X, y)
319319

320320
plot_tsne(X_rca, y)

examples/plot_sandwich.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ def sandwich_demo():
3535

3636
mls = [
3737
LMNN(),
38-
ITML_Supervised(num_constraints=200),
39-
SDML_Supervised(num_constraints=200, balance_param=0.001),
40-
LSML_Supervised(num_constraints=200),
38+
ITML_Supervised(n_constraints=200),
39+
SDML_Supervised(n_constraints=200, balance_param=0.001),
40+
LSML_Supervised(n_constraints=200),
4141
]
4242

4343
for ax_num, ml in enumerate(mls, start=3):

metric_learn/constraints.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sklearn.utils import check_random_state
88
from sklearn.neighbors import NearestNeighbors
99

10+
1011
__all__ = ['Constraints']
1112

1213

@@ -31,21 +32,21 @@ def __init__(self, partial_labels):
3132
partial_labels = np.asanyarray(partial_labels, dtype=int)
3233
self.partial_labels = partial_labels
3334

34-
def positive_negative_pairs(self, num_constraints, same_length=False,
35-
random_state=None):
35+
def positive_negative_pairs(self, n_constraints, same_length=False,
36+
random_state=None, num_constraints='deprecated'):
3637
"""
3738
Generates positive pairs and negative pairs from labeled data.
3839
39-
Positive pairs are formed by randomly drawing ``num_constraints`` pairs of
40+
Positive pairs are formed by randomly drawing ``n_constraints`` pairs of
4041
points with the same label. Negative pairs are formed by randomly drawing
41-
``num_constraints`` pairs of points with different label.
42+
``n_constraints`` pairs of points with different label.
4243
4344
In the case where it is not possible to generate enough positive or
4445
negative pairs, a smaller number of pairs will be returned with a warning.
4546
4647
Parameters
4748
----------
48-
num_constraints : int
49+
n_constraints : int
4950
Number of positive and negative constraints to generate.
5051
5152
same_length : bool, optional (default=False)
@@ -55,6 +56,8 @@ def positive_negative_pairs(self, num_constraints, same_length=False,
5556
random_state : int or numpy.RandomState or None, optional (default=None)
5657
A pseudo random number generator object or a seed for it if int.
5758
59+
num_constraints : Renamed to n_constraints. Will be deprecated in 0.7.0
60+
5861
Returns
5962
-------
6063
a : array-like, shape=(n_constraints,)
@@ -69,10 +72,18 @@ def positive_negative_pairs(self, num_constraints, same_length=False,
6972
d : array-like, shape=(n_constraints,)
7073
1D array of indicators for the right elements of negative pairs.
7174
"""
75+
if num_constraints != 'deprecated':
76+
warnings.warn('"num_constraints" parameter has been renamed to'
77+
' "n_constraints". It has been deprecated in'
78+
' version 0.6.3 and will be removed in 0.7.0'
79+
'', FutureWarning)
80+
self.n_constraints = num_constraints
81+
else:
82+
self.n_constraints = n_constraints
7283
random_state = check_random_state(random_state)
73-
a, b = self._pairs(num_constraints, same_label=True,
84+
a, b = self._pairs(n_constraints, same_label=True,
7485
random_state=random_state)
75-
c, d = self._pairs(num_constraints, same_label=False,
86+
c, d = self._pairs(n_constraints, same_label=False,
7687
random_state=random_state)
7788
if same_length and len(a) != len(c):
7889
n = min(len(a), len(c))
@@ -190,15 +201,15 @@ def generate_knntriplets(self, X, k_genuine, k_impostor):
190201

191202
return triplets
192203

193-
def _pairs(self, num_constraints, same_label=True, max_iter=10,
204+
def _pairs(self, n_constraints, same_label=True, max_iter=10,
194205
random_state=np.random):
195206
known_label_idx, = np.where(self.partial_labels >= 0)
196207
known_labels = self.partial_labels[known_label_idx]
197208
num_labels = len(known_labels)
198209
ab = set()
199210
it = 0
200-
while it < max_iter and len(ab) < num_constraints:
201-
nc = num_constraints - len(ab)
211+
while it < max_iter and len(ab) < n_constraints:
212+
nc = n_constraints - len(ab)
202213
for aidx in random_state.randint(num_labels, size=nc):
203214
if same_label:
204215
mask = known_labels[aidx] == known_labels
@@ -209,25 +220,26 @@ def _pairs(self, num_constraints, same_label=True, max_iter=10,
209220
if len(b_choices) > 0:
210221
ab.add((aidx, random_state.choice(b_choices)))
211222
it += 1
212-
if len(ab) < num_constraints:
223+
if len(ab) < n_constraints:
213224
warnings.warn("Only generated %d %s constraints (requested %d)" % (
214-
len(ab), 'positive' if same_label else 'negative', num_constraints))
215-
ab = np.array(list(ab)[:num_constraints], dtype=int)
225+
len(ab), 'positive' if same_label else 'negative', n_constraints))
226+
ab = np.array(list(ab)[:n_constraints], dtype=int)
216227
return known_label_idx[ab.T]
217228

218-
def chunks(self, num_chunks=100, chunk_size=2, random_state=None):
229+
def chunks(self, n_chunks=100, chunk_size=2, random_state=None,
230+
num_chunks='deprecated'):
219231
"""
220232
Generates chunks from labeled data.
221233
222-
Each of ``num_chunks`` chunks is composed of ``chunk_size`` points from
234+
Each of ``n_chunks`` chunks is composed of ``chunk_size`` points from
223235
the same class drawn at random. Each point can belong to at most 1 chunk.
224236
225-
In the case where there is not enough points to generate ``num_chunks``
237+
In the case where there is not enough points to generate ``n_chunks``
226238
chunks of size ``chunk_size``, a ValueError will be raised.
227239
228240
Parameters
229241
----------
230-
num_chunks : int, optional (default=100)
242+
n_chunks : int, optional (default=100)
231243
Number of chunks to generate.
232244
233245
chunk_size : int, optional (default=2)
@@ -236,26 +248,34 @@ def chunks(self, num_chunks=100, chunk_size=2, random_state=None):
236248
random_state : int or numpy.RandomState or None, optional (default=None)
237249
A pseudo random number generator object or a seed for it if int.
238250
251+
num_chunks : Renamed to n_chunks. Will be deprecated in 0.7.0
252+
239253
Returns
240254
-------
241255
chunks : array-like, shape=(n_samples,)
242256
1D array of chunk indicators, where -1 indicates that the point does not
243257
belong to any chunk.
244258
"""
259+
if num_chunks != 'deprecated':
260+
warnings.warn('"num_chunks" parameter has been renamed to'
261+
' "n_chunks". It has been deprecated in'
262+
' version 0.6.3 and will be removed in 0.7.0'
263+
'', FutureWarning)
264+
n_chunks = num_chunks
245265
random_state = check_random_state(random_state)
246266
chunks = -np.ones_like(self.partial_labels, dtype=int)
247267
uniq, lookup = np.unique(self.partial_labels, return_inverse=True)
248268
unknown_uniq = np.where(uniq < 0)[0]
249269
all_inds = [set(np.where(lookup == c)[0]) for c in range(len(uniq))
250270
if c not in unknown_uniq]
251271
max_chunks = int(np.sum([len(s) // chunk_size for s in all_inds]))
252-
if max_chunks < num_chunks:
272+
if max_chunks < n_chunks:
253273
raise ValueError(('Not enough possible chunks of %d elements in each'
254274
' class to form expected %d chunks - maximum number'
255275
' of chunks is %d'
256-
) % (chunk_size, num_chunks, max_chunks))
276+
) % (chunk_size, n_chunks, max_chunks))
257277
idx = 0
258-
while idx < num_chunks and all_inds:
278+
while idx < n_chunks and all_inds:
259279
if len(all_inds) == 1:
260280
c = 0
261281
else:

metric_learn/itml.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,28 @@
99
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
1010
from .constraints import Constraints, wrap_pairs
1111
from ._util import components_from_metric, _initialize_metric_mahalanobis
12+
import warnings
1213

1314

1415
class _BaseITML(MahalanobisMixin):
1516
"""Information Theoretic Metric Learning (ITML)"""
1617

1718
_tuple_size = 2 # constraints are pairs
1819

19-
def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
20+
def __init__(self, gamma=1., max_iter=1000, tol=1e-3,
2021
prior='identity', verbose=False,
21-
preprocessor=None, random_state=None):
22+
preprocessor=None, random_state=None,
23+
convergence_threshold='deprecated'):
24+
if convergence_threshold != 'deprecated':
25+
warnings.warn('"convergence_threshold" parameter has been '
26+
' renamed to "tol". It has been deprecated in'
27+
' version 0.6.3 and will be removed in 0.7.0'
28+
'', FutureWarning)
29+
tol = convergence_threshold
30+
self.convergence_threshold = 'deprecated' # Avoid errors
2231
self.gamma = gamma
2332
self.max_iter = max_iter
24-
self.convergence_threshold = convergence_threshold
33+
self.tol = tol
2534
self.prior = prior
2635
self.verbose = verbose
2736
self.random_state = random_state
@@ -86,7 +95,7 @@ def _fit(self, pairs, y, bounds=None):
8695
conv = np.inf
8796
break
8897
conv = np.abs(lambdaold - _lambda).sum() / normsum
89-
if conv < self.convergence_threshold:
98+
if conv < self.tol:
9099
break
91100
lambdaold = _lambda.copy()
92101
if self.verbose:
@@ -122,7 +131,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
122131
max_iter : int, optional (default=1000)
123132
Maximum number of iteration of the optimization procedure.
124133
125-
convergence_threshold : float, optional (default=1e-3)
134+
tol : float, optional (default=1e-3)
126135
Convergence tolerance.
127136
128137
prior : string or numpy array, optional (default='identity')
@@ -158,6 +167,8 @@ class ITML(_BaseITML, _PairsClassifierMixin):
158167
A pseudo random number generator object or a seed for it if int. If
159168
``prior='random'``, ``random_state`` is used to set the prior.
160169
170+
convergence_threshold : Renamed to tol. Will be deprecated in 0.7.0
171+
161172
Attributes
162173
----------
163174
bounds_ : `numpy.ndarray`, shape=(2,)
@@ -260,10 +271,10 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
260271
max_iter : int, optional (default=1000)
261272
Maximum number of iterations of the optimization procedure.
262273
263-
convergence_threshold : float, optional (default=1e-3)
274+
tol : float, optional (default=1e-3)
264275
Tolerance of the optimization procedure.
265276
266-
num_constraints : int, optional (default=None)
277+
n_constraints : int, optional (default=None)
267278
Number of constraints to generate. If None, default to `20 *
268279
num_classes**2`.
269280
@@ -302,6 +313,9 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
302313
case, `random_state` is also used to randomly sample constraints from
303314
labels.
304315
316+
num_constraints : Renamed to n_constraints. Will be deprecated in 0.7.0
317+
318+
convergence_threshold : Renamed to tol. Will be deprecated in 0.7.0
305319
306320
Attributes
307321
----------
@@ -328,7 +342,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
328342
>>> iris_data = load_iris()
329343
>>> X = iris_data['data']
330344
>>> Y = iris_data['target']
331-
>>> itml = ITML_Supervised(num_constraints=200)
345+
>>> itml = ITML_Supervised(n_constraints=200)
332346
>>> itml.fit(X, Y)
333347
334348
See Also
@@ -338,14 +352,26 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
338352
that describes the supervised version of weakly supervised estimators.
339353
"""
340354

341-
def __init__(self, gamma=1.0, max_iter=1000, convergence_threshold=1e-3,
342-
num_constraints=None, prior='identity',
343-
verbose=False, preprocessor=None, random_state=None):
355+
def __init__(self, gamma=1.0, max_iter=1000, tol=1e-3,
356+
n_constraints=None, prior='identity',
357+
verbose=False, preprocessor=None, random_state=None,
358+
num_constraints='deprecated',
359+
convergence_threshold='deprecated'):
344360
_BaseITML.__init__(self, gamma=gamma, max_iter=max_iter,
345-
convergence_threshold=convergence_threshold,
361+
tol=tol,
346362
prior=prior, verbose=verbose,
347-
preprocessor=preprocessor, random_state=random_state)
348-
self.num_constraints = num_constraints
363+
preprocessor=preprocessor,
364+
random_state=random_state,
365+
convergence_threshold=convergence_threshold)
366+
if num_constraints != 'deprecated':
367+
warnings.warn('"num_constraints" parameter has been renamed to'
368+
' "n_constraints". It has been deprecated in'
369+
' version 0.6.3 and will be removed in 0.7.0'
370+
'', FutureWarning)
371+
n_constraints = num_constraints
372+
self.n_constraints = n_constraints
373+
# Avoid test get_params from failing (all params passed sholud be set)
374+
self.num_constraints = 'deprecated'
349375

350376
def fit(self, X, y, bounds=None):
351377
"""Create constraints from labels and learn the ITML model.
@@ -369,13 +395,13 @@ def fit(self, X, y, bounds=None):
369395
points in the training data `X`.
370396
"""
371397
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
372-
num_constraints = self.num_constraints
373-
if num_constraints is None:
398+
n_constraints = self.n_constraints
399+
if n_constraints is None:
374400
num_classes = len(np.unique(y))
375-
num_constraints = 20 * num_classes**2
401+
n_constraints = 20 * num_classes**2
376402

377403
c = Constraints(y)
378-
pos_neg = c.positive_negative_pairs(num_constraints,
404+
pos_neg = c.positive_negative_pairs(n_constraints,
379405
random_state=self.random_state)
380406
pairs, y = wrap_pairs(X, pos_neg)
381407
return _BaseITML._fit(self, pairs, y, bounds=bounds)

0 commit comments

Comments
 (0)