diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index 47bb065f..7865cbe5 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -172,7 +172,7 @@ def fit(self, X, y): self.verbose, random_state=self.random_state) required_k = np.bincount(label_inds).min() - if self.n_neighbors > required_k: + if self.n_neighbors >= required_k: raise ValueError('not enough class labels for specified k' ' (smallest class has %d)' % required_k)