Skip to content

Commit caba555

Browse files
committed
SCML: Raise ValueError if n_features larger than n_triplets
1 parent 72b76c8 commit caba555

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

metric_learn/scml.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,11 @@ def _generate_bases_dist_diff(self, triplets, X):
240240
raise ValueError("n_basis should be an integer, instead it is of type %s"
241241
% type(self.n_basis))
242242

243+
if n_features > n_triplets:
244+
raise ValueError("Number of features (%s) is greater than the nuber of triplets(%s)."
245+
"\nConsider using a dimensionality reduction preprocessing or create "
246+
"a new basis generation scheme." % (n_features, n_triplets))
247+
243248
basis = np.zeros((n_basis, n_features))
244249

245250
# get all positive and negative pairs with lowest index first

test/test_triplets_classifiers.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from sklearn.exceptions import NotFittedError
33
from sklearn.model_selection import train_test_split
44

5-
from test.test_utils import triplets_learners, ids_triplets_learners
5+
from metric_learn import SCML
6+
from test.test_utils import triplets_learners, ids_triplets_learners, build_triplets
67
from metric_learn.sklearn_shims import set_random_state
78
from sklearn import clone
89
import numpy as np
@@ -107,3 +108,16 @@ def test_accuracy_toy_example(estimator, build_dataset):
107108
# we force the transformation to be identity so that we control what it does
108109
estimator.components_ = np.eye(X.shape[1])
109110
assert estimator.score(triplets_test) == 0.25
111+
112+
113+
def test_raise_big_number_of_features():
114+
triplets, _, _, X = build_triplets(with_preprocessor=False)
115+
triplets = triplets[:3, :, :]
116+
estimator = SCML(n_basis=320)
117+
set_random_state(estimator)
118+
with pytest.raises(ValueError) as exc_info:
119+
estimator.fit(triplets)
120+
assert exc_info.value.args[0] == \
121+
"Number of features (4) is greater than the nuber of triplets(3).\n" \
122+
"Consider using a dimensionality reduction preprocessing or create " \
123+
"a new basis generation scheme."

0 commit comments

Comments
 (0)