Skip to content

Commit 4b889d4

Browse files
toto6perimosocordiae
authored andcommitted
LMNN: fix mistake and improve performances (#78)
Fix mistake in LMNN Issue in function _find_impostors: - the squared euclidean distance is used to compute the margins in variable "margin_radii" - the euclidean distance is used (through the function sklearn.metrics.pairwise.pairwise_distances) to compute distances between samples of different labels in variable "dist" - the issue is that the impostors are found by testing "dist < margin_radii" which is wrong because "dist" represent distances, and "margin_radii" represent squared distances. I propose to solve this problem by computing always the squared distances.
1 parent 84dbcbe commit 4b889d4

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

metric_learn/lmnn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
import warnings
1515
from collections import Counter
1616
from six.moves import xrange
17-
from sklearn.metrics import pairwise_distances
1817
from sklearn.utils.validation import check_X_y, check_array
18+
from sklearn.metrics import euclidean_distances
1919

2020
from .base_metric import BaseMetricLearner
2121

@@ -185,7 +185,7 @@ def _select_targets(self):
185185
target_neighbors = np.empty((self.X_.shape[0], self.k), dtype=int)
186186
for label in self.labels_:
187187
inds, = np.nonzero(self.label_inds_ == label)
188-
dd = pairwise_distances(self.X_[inds])
188+
dd = euclidean_distances(self.X_[inds], squared=True)
189189
np.fill_diagonal(dd, np.inf)
190190
nn = np.argsort(dd)[..., :self.k]
191191
target_neighbors[inds] = inds[nn]
@@ -198,7 +198,7 @@ def _find_impostors(self, furthest_neighbors):
198198
for label in self.labels_[:-1]:
199199
in_inds, = np.nonzero(self.label_inds_ == label)
200200
out_inds, = np.nonzero(self.label_inds_ > label)
201-
dist = pairwise_distances(Lx[out_inds], Lx[in_inds])
201+
dist = euclidean_distances(Lx[out_inds], Lx[in_inds], squared=True)
202202
i1,j1 = np.nonzero(dist < margin_radii[out_inds][:,None])
203203
i2,j2 = np.nonzero(dist < margin_radii[in_inds])
204204
i = np.hstack((i1,i2))

0 commit comments

Comments
 (0)