Skip to content

Commit abf79d1

Browse files
author
mvargas33
committed
LMNN k parameter renamed to n_neighbors
1 parent 1797c22 commit abf79d1

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

metric_learn/lmnn.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class LMNN(MahalanobisMixin, TransformerMixin):
6363
:meth:`fit` and n_features_a must be less than or equal to that.
6464
If ``n_components`` is not None, n_features_a must match it.
6565
66-
k : int, optional (default=3)
66+
n_neighbors : int, optional (default=3)
6767
Number of neighbors to consider, not including self-edges.
6868
6969
min_iter : int, optional (default=50)
@@ -128,12 +128,12 @@ class LMNN(MahalanobisMixin, TransformerMixin):
128128
2005.
129129
"""
130130

131-
def __init__(self, init='auto', k=3, min_iter=50, max_iter=1000,
131+
def __init__(self, init='auto', n_neighbors=3, min_iter=50, max_iter=1000,
132132
learn_rate=1e-7, regularization=0.5, convergence_tol=0.001,
133133
verbose=False, preprocessor=None,
134134
n_components=None, random_state=None):
135135
self.init = init
136-
self.k = k
136+
self.n_neighbors = n_neighbors
137137
self.min_iter = min_iter
138138
self.max_iter = max_iter
139139
self.learn_rate = learn_rate
@@ -145,7 +145,7 @@ def __init__(self, init='auto', k=3, min_iter=50, max_iter=1000,
145145
super(LMNN, self).__init__(preprocessor)
146146

147147
def fit(self, X, y):
148-
k = self.k
148+
k = self.n_neighbors
149149
reg = self.regularization
150150
learn_rate = self.learn_rate
151151

@@ -162,7 +162,7 @@ def fit(self, X, y):
162162
self.verbose,
163163
random_state=self.random_state)
164164
required_k = np.bincount(label_inds).min()
165-
if self.k > required_k:
165+
if self.n_neighbors > required_k:
166166
raise ValueError('not enough class labels for specified k'
167167
' (smallest class has %d)' % required_k)
168168

@@ -275,12 +275,12 @@ def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds):
275275
return 2 * G, objective, total_active
276276

277277
def _select_targets(self, X, label_inds):
278-
target_neighbors = np.empty((X.shape[0], self.k), dtype=int)
278+
target_neighbors = np.empty((X.shape[0], self.n_neighbors), dtype=int)
279279
for label in self.labels_:
280280
inds, = np.nonzero(label_inds == label)
281281
dd = euclidean_distances(X[inds], squared=True)
282282
np.fill_diagonal(dd, np.inf)
283-
nn = np.argsort(dd)[..., :self.k]
283+
nn = np.argsort(dd)[..., :self.n_neighbors]
284284
target_neighbors[inds] = inds[nn]
285285
return target_neighbors
286286

0 commit comments

Comments
 (0)