
[Tests][Warnings] Cut all warnings from SCML using a minimal solution #341


Merged: 1 commit, merged on Nov 2, 2021
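
For context (not part of the diff): the changes below revolve around the `UserWarning` that SCML emits when `n_basis` is left at its default of `None`. The sketch here is illustrative only; the data is a stand-in modeled on the tests rather than an exact fixture, and the small `max_iter` is just to keep it quick.

```python
import warnings

from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from metric_learn import SCML_Supervised

# Toy data shaped like the parametrized tests below (values are illustrative).
X, y = make_classification(n_samples=100, n_classes=5, n_features=10,
                           n_informative=10, n_redundant=0, n_repeated=0,
                           random_state=42)
X = StandardScaler().fit_transform(X)

# With the default n_basis=None, fit() emits a UserWarning starting with
# "As no value for `n_basis` was selected, ".
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    SCML_Supervised(n_basis=None, max_iter=500, random_state=42).fit(X, y)
assert any("n_basis" in str(w.message) for w in caught)

# Passing n_basis explicitly (as these tests now do) avoids that warning.
SCML_Supervised(n_basis=80, max_iter=500, random_state=42).fit(X, y)
```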
31 changes: 23 additions & 8 deletions test/metric_learn_test.py
@@ -95,7 +95,7 @@ def test_big_n_features(self):
n_informative=60, n_redundant=0, n_repeated=0,
random_state=42)
X = StandardScaler().fit_transform(X)
scml = SCML_Supervised(random_state=42)
scml = SCML_Supervised(random_state=42, n_basis=399)
scml.fit(X, y)
csep = class_separation(scml.transform(X), y)
assert csep < 0.7
@@ -106,7 +106,7 @@ def test_big_n_features(self):
[2, 0], [2, 1]]),
np.array([1, 0, 1, 0])))])
def test_bad_basis(self, estimator, data):
model = estimator(basis='bad_basis')
model = estimator(basis='bad_basis', n_basis=33) # n_basis doesn't matter
msg = ("`basis` must be one of the options '{}' or an array of shape "
"(n_basis, n_features)."
.format("', '".join(model._authorized_basis)))
@@ -238,16 +238,23 @@ def test_lda_toy(self):
@pytest.mark.parametrize('n_features', [10, 50, 100])
@pytest.mark.parametrize('n_classes', [5, 10, 15])
def test_triplet_diffs(self, n_samples, n_features, n_classes):
"""
Test that the correct value of n_basis is being generated with
different triplet constraints.
"""
X, y = make_classification(n_samples=n_samples, n_classes=n_classes,
n_features=n_features, n_informative=n_features,
n_redundant=0, n_repeated=0)
X = StandardScaler().fit_transform(X)

model = SCML_Supervised()
model = SCML_Supervised(n_basis=None) # Explicit n_basis=None
constraints = Constraints(y)
triplets = constraints.generate_knntriplets(X, model.k_genuine,
model.k_impostor)
basis, n_basis = model._generate_bases_dist_diff(triplets, X)

msg = "As no value for `n_basis` was selected, "
with pytest.warns(UserWarning) as raised_warning:
basis, n_basis = model._generate_bases_dist_diff(triplets, X)
assert msg in str(raised_warning[0].message)

expected_n_basis = n_features * 80
assert n_basis == expected_n_basis
@@ -257,13 +257,21 @@ def test_triplet_diffs(self, n_samples, n_features, n_classes):
@pytest.mark.parametrize('n_features', [10, 50, 100])
@pytest.mark.parametrize('n_classes', [5, 10, 15])
def test_lda(self, n_samples, n_features, n_classes):
"""
Test that when n_basis=None, the correct n_basis is generated,
for SCML_Supervised and different values of n_samples, n_features
and n_classes.
"""
X, y = make_classification(n_samples=n_samples, n_classes=n_classes,
n_features=n_features, n_informative=n_features,
n_redundant=0, n_repeated=0)
X = StandardScaler().fit_transform(X)

model = SCML_Supervised()
basis, n_basis = model._generate_bases_LDA(X, y)
msg = "As no value for `n_basis` was selected, "
with pytest.warns(UserWarning) as raised_warning:
model = SCML_Supervised(n_basis=None) # Explicit n_basis=None
basis, n_basis = model._generate_bases_LDA(X, y)
assert msg in str(raised_warning[0].message)

num_eig = min(n_classes - 1, n_features)
expected_n_basis = min(20 * n_features, n_samples * 2 * num_eig - 1)
@@ -299,7 +314,7 @@ def test_int_inputs_supervised(self, name):
assert msg == raised_error.value.args[0]

def test_large_output_iter(self):
scml = SCML(max_iter=1, output_iter=2)
scml = SCML(max_iter=1, output_iter=2, n_basis=33)  # n_basis doesn't matter
triplets = np.array([[[0, 1], [2, 1], [0, 0]]])
msg = ("The value of output_iter must be equal or smaller than"
" max_iter.")
4 changes: 2 additions & 2 deletions test/test_mahalanobis_mixin.py
@@ -291,8 +291,8 @@ def test_components_is_2D(estimator, build_dataset):
model.fit(*remove_y(estimator, input_data, labels))
assert model.components_.shape == (X.shape[1], X.shape[1])

# test that it works for 1 feature
trunc_data = input_data[..., :1]
# test that it works for 1 feature. Use the 2nd feature column to avoid edge cases
trunc_data = input_data[..., 1:2]
# we drop duplicates that might have been formed, i.e. of the form
# aabc or abcc or aabb for quadruplets, and aa for pairs.

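A small NumPy sketch (not part of the PR) of what the slicing change above does: both expressions keep the trailing feature axis, but `[..., 1:2]` selects the second column instead of the first. The array below is a made-up stand-in for the test's `input_data`.

```python
import numpy as np

# A toy pairs-style input: 4 pairs of points with 3 features each,
# i.e. shape (n_pairs, 2, n_features); the values are arbitrary.
input_data = np.arange(24, dtype=float).reshape(4, 2, 3)

first_feature = input_data[..., :1]    # column 0 only, shape (4, 2, 1)
second_feature = input_data[..., 1:2]  # column 1 only, shape (4, 2, 1)
assert first_feature.shape == second_feature.shape == (4, 2, 1)
```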
5 changes: 4 additions & 1 deletion test/test_sklearn_compat.py
@@ -79,7 +79,10 @@ def test_rca(self):
check_estimator(Stable_RCA_Supervised())

def test_scml(self):
check_estimator(SCML_Supervised())
msg = "As no value for `n_basis` was selected, "
with pytest.warns(UserWarning) as raised_warning:
check_estimator(SCML_Supervised())
assert msg in str(raised_warning[0].message)


RNG = check_random_state(0)
25 changes: 3 additions & 22 deletions test/test_triplets_classifiers.py
@@ -1,7 +1,6 @@
import pytest
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split
import metric_learn

from test.test_utils import triplets_learners, ids_triplets_learners
from metric_learn.sklearn_shims import set_random_state
@@ -21,13 +20,7 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset,
estimator.set_params(preprocessor=preprocessor)
set_random_state(estimator)
triplets_train, triplets_test = train_test_split(input_data)
if isinstance(estimator, metric_learn.SCML):
msg = "As no value for `n_basis` was selected, "
with pytest.warns(UserWarning) as raised_warning:
estimator.fit(triplets_train)
assert msg in str(raised_warning[0].message)
else:
estimator.fit(triplets_train)
estimator.fit(triplets_train)
predictions = estimator.predict(triplets_test)

not_valid = [e for e in predictions if e not in [-1, 1]]
@@ -49,13 +42,7 @@ def test_no_zero_prediction(estimator, build_dataset):
# Dummy fit
estimator = clone(estimator)
set_random_state(estimator)
if isinstance(estimator, metric_learn.SCML):
msg = "As no value for `n_basis` was selected, "
with pytest.warns(UserWarning) as raised_warning:
estimator.fit(triplets)
assert msg in str(raised_warning[0].message)
else:
estimator.fit(triplets)
estimator.fit(triplets)
# We force the transformation to be identity, to force euclidean distance
estimator.components_ = np.eye(X.shape[1])

@@ -106,13 +93,7 @@ def test_accuracy_toy_example(estimator, build_dataset):
triplets, _, _, X = build_dataset(with_preprocessor=False)
estimator = clone(estimator)
set_random_state(estimator)
if isinstance(estimator, metric_learn.SCML):
msg = "As no value for `n_basis` was selected, "
with pytest.warns(UserWarning) as raised_warning:
estimator.fit(triplets)
assert msg in str(raised_warning[0].message)
else:
estimator.fit(triplets)
estimator.fit(triplets)
# We take the two first points and we build 4 regularly spaced points on the
# line they define, so that it's easy to build triplets of different
# similarities.
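
The SCML-specific `pytest.warns` branches above became unnecessary because the shared fixtures in test/test_utils.py (next file) now construct SCML with an explicit `n_basis`. As a side note, a minimal sketch of one way to keep guarding against the warning creeping back in; `estimator` and `triplets` are stand-ins for what the fixtures provide:

```python
import warnings

def fit_without_user_warnings(estimator, triplets):
    # Turn any UserWarning raised during fit (e.g. the SCML n_basis warning)
    # into an error, so a regression would fail loudly.
    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        return estimator.fit(triplets)
```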
4 changes: 2 additions & 2 deletions test/test_utils.py
@@ -117,7 +117,7 @@ def build_quadruplets(with_preprocessor=False):
[learner for (learner, _) in
quadruplets_learners]))

triplets_learners = [(SCML(), build_triplets)]
triplets_learners = [(SCML(n_basis=320), build_triplets)]
ids_triplets_learners = list(map(lambda x: x.__class__.__name__,
[learner for (learner, _) in
triplets_learners]))
@@ -140,7 +140,7 @@ def build_quadruplets(with_preprocessor=False):
(RCA_Supervised(num_chunks=5), build_classification),
(SDML_Supervised(prior='identity', balance_param=1e-5),
build_classification),
(SCML_Supervised(), build_classification)]
(SCML_Supervised(n_basis=80), build_classification)]
ids_classifiers = list(map(lambda x: x.__class__.__name__,
[learner for (learner, _) in
classifiers]))