diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 542e1e0a..fe1560c2 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -95,7 +95,7 @@ def test_big_n_features(self): n_informative=60, n_redundant=0, n_repeated=0, random_state=42) X = StandardScaler().fit_transform(X) - scml = SCML_Supervised(random_state=42) + scml = SCML_Supervised(random_state=42, n_basis=399) scml.fit(X, y) csep = class_separation(scml.transform(X), y) assert csep < 0.7 @@ -106,7 +106,7 @@ def test_big_n_features(self): [2, 0], [2, 1]]), np.array([1, 0, 1, 0])))]) def test_bad_basis(self, estimator, data): - model = estimator(basis='bad_basis') + model = estimator(basis='bad_basis', n_basis=33) # n_basis doesn't matter msg = ("`basis` must be one of the options '{}' or an array of shape " "(n_basis, n_features)." .format("', '".join(model._authorized_basis))) @@ -238,16 +238,23 @@ def test_lda_toy(self): @pytest.mark.parametrize('n_features', [10, 50, 100]) @pytest.mark.parametrize('n_classes', [5, 10, 15]) def test_triplet_diffs(self, n_samples, n_features, n_classes): + """ + Test that the correct value of n_basis is being generated with + different triplet constraints. + """ X, y = make_classification(n_samples=n_samples, n_classes=n_classes, n_features=n_features, n_informative=n_features, n_redundant=0, n_repeated=0) X = StandardScaler().fit_transform(X) - - model = SCML_Supervised() + model = SCML_Supervised(n_basis=None) # Explicit n_basis=None constraints = Constraints(y) triplets = constraints.generate_knntriplets(X, model.k_genuine, model.k_impostor) - basis, n_basis = model._generate_bases_dist_diff(triplets, X) + + msg = "As no value for `n_basis` was selected, " + with pytest.warns(UserWarning) as raised_warning: + basis, n_basis = model._generate_bases_dist_diff(triplets, X) + assert msg in str(raised_warning[0].message) expected_n_basis = n_features * 80 assert n_basis == expected_n_basis @@ -257,13 +264,21 @@ def test_triplet_diffs(self, n_samples, n_features, n_classes): @pytest.mark.parametrize('n_features', [10, 50, 100]) @pytest.mark.parametrize('n_classes', [5, 10, 15]) def test_lda(self, n_samples, n_features, n_classes): + """ + Test that when n_basis=None, the correct n_basis is generated, + for SCML_Supervised and different values of n_samples, n_features + and n_classes. + """ X, y = make_classification(n_samples=n_samples, n_classes=n_classes, n_features=n_features, n_informative=n_features, n_redundant=0, n_repeated=0) X = StandardScaler().fit_transform(X) - model = SCML_Supervised() - basis, n_basis = model._generate_bases_LDA(X, y) + msg = "As no value for `n_basis` was selected, " + with pytest.warns(UserWarning) as raised_warning: + model = SCML_Supervised(n_basis=None) # Explicit n_basis=None + basis, n_basis = model._generate_bases_LDA(X, y) + assert msg in str(raised_warning[0].message) num_eig = min(n_classes - 1, n_features) expected_n_basis = min(20 * n_features, n_samples * 2 * num_eig - 1) @@ -299,7 +314,7 @@ def test_int_inputs_supervised(self, name): assert msg == raised_error.value.args[0] def test_large_output_iter(self): - scml = SCML(max_iter=1, output_iter=2) + scml = SCML(max_iter=1, output_iter=2, n_basis=33) # n_basis don't matter triplets = np.array([[[0, 1], [2, 1], [0, 0]]]) msg = ("The value of output_iter must be equal or smaller than" " max_iter.") diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py index e2aa1e4d..e69aa032 100644 --- a/test/test_mahalanobis_mixin.py +++ b/test/test_mahalanobis_mixin.py @@ -291,8 +291,8 @@ def test_components_is_2D(estimator, build_dataset): model.fit(*remove_y(estimator, input_data, labels)) assert model.components_.shape == (X.shape[1], X.shape[1]) - # test that it works for 1 feature - trunc_data = input_data[..., :1] + # test that it works for 1 feature. Use 2nd dimention, to avoid border cases + trunc_data = input_data[..., 1:2] # we drop duplicates that might have been formed, i.e. of the form # aabc or abcc or aabb for quadruplets, and aa for pairs. diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index a23a88d0..d2369b1c 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -79,7 +79,10 @@ def test_rca(self): check_estimator(Stable_RCA_Supervised()) def test_scml(self): - check_estimator(SCML_Supervised()) + msg = "As no value for `n_basis` was selected, " + with pytest.warns(UserWarning) as raised_warning: + check_estimator(SCML_Supervised()) + assert msg in str(raised_warning[0].message) RNG = check_random_state(0) diff --git a/test/test_triplets_classifiers.py b/test/test_triplets_classifiers.py index 600947e6..f2d5c015 100644 --- a/test/test_triplets_classifiers.py +++ b/test/test_triplets_classifiers.py @@ -1,7 +1,6 @@ import pytest from sklearn.exceptions import NotFittedError from sklearn.model_selection import train_test_split -import metric_learn from test.test_utils import triplets_learners, ids_triplets_learners from metric_learn.sklearn_shims import set_random_state @@ -21,13 +20,7 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset, estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) triplets_train, triplets_test = train_test_split(input_data) - if isinstance(estimator, metric_learn.SCML): - msg = "As no value for `n_basis` was selected, " - with pytest.warns(UserWarning) as raised_warning: - estimator.fit(triplets_train) - assert msg in str(raised_warning[0].message) - else: - estimator.fit(triplets_train) + estimator.fit(triplets_train) predictions = estimator.predict(triplets_test) not_valid = [e for e in predictions if e not in [-1, 1]] @@ -49,13 +42,7 @@ def test_no_zero_prediction(estimator, build_dataset): # Dummy fit estimator = clone(estimator) set_random_state(estimator) - if isinstance(estimator, metric_learn.SCML): - msg = "As no value for `n_basis` was selected, " - with pytest.warns(UserWarning) as raised_warning: - estimator.fit(triplets) - assert msg in str(raised_warning[0].message) - else: - estimator.fit(triplets) + estimator.fit(triplets) # We force the transformation to be identity, to force euclidean distance estimator.components_ = np.eye(X.shape[1]) @@ -106,13 +93,7 @@ def test_accuracy_toy_example(estimator, build_dataset): triplets, _, _, X = build_dataset(with_preprocessor=False) estimator = clone(estimator) set_random_state(estimator) - if isinstance(estimator, metric_learn.SCML): - msg = "As no value for `n_basis` was selected, " - with pytest.warns(UserWarning) as raised_warning: - estimator.fit(triplets) - assert msg in str(raised_warning[0].message) - else: - estimator.fit(triplets) + estimator.fit(triplets) # We take the two first points and we build 4 regularly spaced points on the # line they define, so that it's easy to build triplets of different # similarities. diff --git a/test/test_utils.py b/test/test_utils.py index 83bdd86a..f3000344 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -117,7 +117,7 @@ def build_quadruplets(with_preprocessor=False): [learner for (learner, _) in quadruplets_learners])) -triplets_learners = [(SCML(), build_triplets)] +triplets_learners = [(SCML(n_basis=320), build_triplets)] ids_triplets_learners = list(map(lambda x: x.__class__.__name__, [learner for (learner, _) in triplets_learners])) @@ -140,7 +140,7 @@ def build_quadruplets(with_preprocessor=False): (RCA_Supervised(num_chunks=5), build_classification), (SDML_Supervised(prior='identity', balance_param=1e-5), build_classification), - (SCML_Supervised(), build_classification)] + (SCML_Supervised(n_basis=80), build_classification)] ids_classifiers = list(map(lambda x: x.__class__.__name__, [learner for (learner, _) in classifiers]))