diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index 757d1be5..e580f3ed 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -63,6 +63,9 @@ def fit(self, X, labels): target_neighbors = self._select_targets() impostors = self._find_impostors(target_neighbors[:,-1]) + if len(impostors) == 0: + # L has already been initialized to an identity matrix of requisite shape + return # sum outer products dfG = _sum_outer_products(self.X, target_neighbors.flatten(), @@ -203,6 +206,9 @@ def _find_impostors(self, furthest_neighbors): tmp = np.ravel_multi_index((i,j), shape) i,j = np.unravel_index(np.unique(tmp), shape) impostors.append(np.vstack((in_inds[j], out_inds[i]))) + if len(impostors) == 0: + # No impostors detected + return impostors return np.hstack(impostors)