Skip to content

Commit 8520418

Browse files
authored
SCML: Raise ValueError if n_features larger than n_triplets (#350)
1 parent d78c720 commit 8520418

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

metric_learn/scml.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,12 @@ def _generate_bases_dist_diff(self, triplets, X):
240240
raise ValueError("n_basis should be an integer, instead it is of type %s"
241241
% type(self.n_basis))
242242

243+
if n_features > n_triplets:
244+
raise ValueError(
245+
"Number of features (%s) is greater than the number of triplets(%s).\n"
246+
"Consider using dimensionality reduction or using another basis "
247+
"generation scheme." % (n_features, n_triplets))
248+
243249
basis = np.zeros((n_basis, n_features))
244250

245251
# get all positive and negative pairs with lowest index first

test/test_triplets_classifiers.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
from sklearn.exceptions import NotFittedError
33
from sklearn.model_selection import train_test_split
44

5-
from test.test_utils import triplets_learners, ids_triplets_learners
5+
from metric_learn import SCML
6+
from test.test_utils import (
7+
triplets_learners,
8+
ids_triplets_learners,
9+
build_triplets
10+
)
611
from metric_learn.sklearn_shims import set_random_state
712
from sklearn import clone
813
import numpy as np
@@ -107,3 +112,16 @@ def test_accuracy_toy_example(estimator, build_dataset):
107112
# we force the transformation to be identity so that we control what it does
108113
estimator.components_ = np.eye(X.shape[1])
109114
assert estimator.score(triplets_test) == 0.25
115+
116+
117+
def test_raise_big_number_of_features():
118+
triplets, _, _, X = build_triplets(with_preprocessor=False)
119+
triplets = triplets[:3, :, :]
120+
estimator = SCML(n_basis=320)
121+
set_random_state(estimator)
122+
with pytest.raises(ValueError) as exc_info:
123+
estimator.fit(triplets)
124+
assert exc_info.value.args[0] == \
125+
"Number of features (4) is greater than the number of triplets(3)." \
126+
"\nConsider using dimensionality reduction or using another basis " \
127+
"generation scheme."

0 commit comments

Comments
 (0)