Fix 7 sources of warnings in the tests #339

Merged: 5 commits, Oct 21, 2021

2 changes: 1 addition & 1 deletion metric_learn/_util.py
@@ -704,7 +704,7 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None,
  elif init == 'covariance':
    if input.ndim == 3:
      # if the input are tuples, we need to form an X by deduplication
-     X = np.vstack({tuple(row) for row in input.reshape(-1, n_features)})
+     X = np.unique(np.vstack(input), axis=0)
A Contributor commented:
This may be faster in some cases, and more RAM-hungry in others. Overall I think it's probably fine, and if users are particularly sensitive to the RAM requirement they can perform their own dedup and pass a 2-d input directly.

    else:
      X = input
    # atleast2d is necessary to deal with scalar covariance matrices
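
For reference, the two deduplication strategies agree up to row order; a minimal sketch with made-up shapes (also showing why the old spelling warned: stacking a bare set is NumPy's deprecated non-sequence-iterable path):

```python
# Sketch, illustrative shapes only: old set-based dedup vs. np.unique.
import numpy as np

pairs = np.array([[[0., 1.], [2., 3.]],
                  [[0., 1.], [4., 5.]]])  # (n_tuples, 2, n_features)
n_features = pairs.shape[-1]

# Old approach: a set of row-tuples. Wrapped in list() here, because passing
# a bare set to np.vstack is the deprecated path this PR moves off of.
old = np.vstack(list({tuple(row) for row in pairs.reshape(-1, n_features)}))

# New approach: dedup rows directly, returned in sorted order.
new = np.unique(np.vstack(pairs), axis=0)

# Same rows, possibly different order (sets are unordered).
assert np.array_equal(np.unique(old, axis=0), new)
```
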
2 changes: 1 addition & 1 deletion metric_learn/itml.py
@@ -32,7 +32,7 @@ def _fit(self, pairs, y, bounds=None):
                                     type_of_inputs='tuples')
    # init bounds
    if bounds is None:
-     X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
+     X = np.unique(np.vstack(pairs), axis=0)
      self.bounds_ = np.percentile(pairwise_distances(X), (5, 95))
    else:
      bounds = check_array(bounds, allow_nd=False, ensure_min_samples=0,
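
For context, a self-contained sketch of how ITML's default `bounds_` fall out of the deduplicated points (random data, illustrative only):

```python
# Sketch: ITML's default bounds_ are the 5th/95th percentiles of all
# pairwise distances over the deduplicated points. Data here is random.
import numpy as np
from sklearn.metrics import pairwise_distances

rng = np.random.RandomState(0)
pairs = rng.randn(20, 2, 3)               # 20 pairs of 3-d points
X = np.unique(np.vstack(pairs), axis=0)   # the dedup from the diff above
bounds = np.percentile(pairwise_distances(X), (5, 95))
print(bounds)  # lower/upper distance thresholds for similar/dissimilar pairs
```
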
2 changes: 1 addition & 1 deletion metric_learn/rca.py
@@ -112,7 +112,7 @@ def fit(self, X, chunks):
    # Fisher Linear Discriminant projection
    if dim < X.shape[1]:
      total_cov = np.cov(X[chunk_mask], rowvar=0)
-     tmp = np.linalg.lstsq(total_cov, inner_cov)[0]
+     tmp = np.linalg.lstsq(total_cov, inner_cov, rcond=None)[0]
      vals, vecs = np.linalg.eig(tmp)
      inds = np.argsort(vals)[:dim]
      A = vecs[:, inds]
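
The warning silenced here is NumPy's `FutureWarning` about `lstsq`'s changing `rcond` default; a minimal reproduction:

```python
# Sketch: on older NumPy, omitting rcond emits a FutureWarning saying the
# rcond default will change to machine precision times max(M, N).
# Passing rcond=None opts in to the new default and silences it.
import numpy as np

A = np.array([[1., 2.], [3., 4.], [5., 6.]])
b = np.array([1., 2., 3.])
x, residuals, rank, sv = np.linalg.lstsq(A, b, rcond=None)
```
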
6 changes: 3 additions & 3 deletions metric_learn/scml.py
@@ -615,10 +615,10 @@ def _generate_bases_LDA(self, X, y):
    k_class = np.vstack((np.minimum(class_count, scales[0]),
                         np.minimum(class_count, scales[1])))

-   idx_set = [np.zeros((n_clusters, sum(k_class[0, :])), dtype=np.int),
-              np.zeros((n_clusters, sum(k_class[1, :])), dtype=np.int)]
+   idx_set = [np.zeros((n_clusters, sum(k_class[0, :])), dtype=np.int64),
+              np.zeros((n_clusters, sum(k_class[1, :])), dtype=np.int64)]

-   start_finish_indices = np.hstack((np.zeros((2, 1), np.int),
+   start_finish_indices = np.hstack((np.zeros((2, 1), np.int64),
                                      k_class)).cumsum(axis=1)

    neigh = NearestNeighbors()
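
`np.int` was only ever an alias of the builtin `int`; NumPy 1.20 deprecated it (and 1.24 removed it), so an explicit fixed-width dtype is the durable spelling:

```python
# Sketch: the deprecated alias vs. explicit dtypes.
import numpy as np

# np.zeros((2, 1), dtype=np.int)     # DeprecationWarning on NumPy >= 1.20,
#                                    # AttributeError on >= 1.24
a = np.zeros((2, 1), dtype=np.int64) # explicit 64-bit integer
b = np.zeros((2, 1), dtype=int)      # also fine: the builtin, platform-sized
```
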
4 changes: 4 additions & 0 deletions pytest.ini
@@ -0,0 +1,4 @@
+ [pytest]
+ markers =
+     integration: mark a test as integration
+     unit: mark a test as unit
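
Registering the markers silences pytest's `PytestUnknownMarkWarning` for tests tagged like this (test name here is illustrative):

```python
import pytest

@pytest.mark.unit
def test_small_and_fast():
    assert 1 + 1 == 2
```

With the ini entries in place, `pytest -m unit` selects only unit-marked tests and `pytest -m "not integration"` skips the slow ones.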
8 changes: 4 additions & 4 deletions test/metric_learn_test.py
@@ -9,7 +9,6 @@
                            make_spd_matrix)
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
                           assert_allclose)
- from metric_learn.sklearn_shims import assert_warns_message
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils.validation import check_X_y
from sklearn.preprocessing import StandardScaler
@@ -1143,9 +1142,10 @@ def test_convergence_warning(dataset, algo_class):
  X, y = dataset
  model = algo_class(max_iter=2, verbose=True)
  cls_name = model.__class__.__name__
-  assert_warns_message(ConvergenceWarning,
-                       '[{}] {} did not converge'.format(cls_name, cls_name),
-                       model.fit, X, y)
+  msg = '[{}] {} did not converge'.format(cls_name, cls_name)
+  with pytest.warns(Warning) as raised_warning:
+    model.fit(X, y)
+  assert any(msg in str(warn.message) for warn in raised_warning)


if __name__ == '__main__':
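
An alternative, slightly tighter spelling (not what the PR merged) uses `pytest.warns(match=...)`; note the `re.escape`, since the expected message starts with a literal `[`:

```python
# Sketch of an equivalent assertion via match=; a hypothetical refactor.
import re
import pytest
from sklearn.exceptions import ConvergenceWarning

def check_convergence_warning(model, X, y):
  cls_name = model.__class__.__name__
  msg = '[{}] {} did not converge'.format(cls_name, cls_name)
  # match= takes a regex searched against the message, hence the escaping
  with pytest.warns(ConvergenceWarning, match=re.escape(msg)):
    model.fit(X, y)
```
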
21 changes: 18 additions & 3 deletions test/test_constraints.py
@@ -103,7 +103,7 @@ def test_generate_knntriplets_under_edge(k_genuine, k_impostor, T_test):


@pytest.mark.parametrize("k_genuine, k_impostor,",
-                         [(2, 3), (3, 3), (2, 4), (3, 4)])
+                         [(3, 3), (2, 4), (3, 4), (10, 9), (144, 33)])
def test_generate_knntriplets(k_genuine, k_impostor):
  """Checks edge and over-the-edge cases of knn triplet construction when
  there are not enough neighbors"""
@@ -118,8 +118,23 @@ def test_generate_knntriplets(k_genuine, k_impostor):
  X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [33, 33]])
  y = np.array([1, 1, 1, 2, 2, 2, -1])

-  T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor)
-
+  msg1 = ("The class 1 has 3 elements, which is not sufficient to "
+          f"generate {k_genuine+1} genuine neighbors "
+          "as specified by k_genuine")
+  msg2 = ("The class 2 has 3 elements, which is not sufficient to "
+          f"generate {k_genuine+1} genuine neighbors "
+          "as specified by k_genuine")
+  msg3 = ("The class 1 has 3 elements of other classes, which is "
+          f"not sufficient to generate {k_impostor} impostor "
+          "neighbors as specified by k_impostor")
+  msg4 = ("The class 2 has 3 elements of other classes, which is "
+          f"not sufficient to generate {k_impostor} impostor "
+          "neighbors as specified by k_impostor")
+  msgs = [msg1, msg2, msg3, msg4]
+  with pytest.warns(UserWarning) as user_warning:
+    T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor)
+  # at least one of the expected insufficient-neighbors messages was raised
+  assert any(msg in str(warn.message)
+             for warn in user_warning for msg in msgs)
  assert np.array_equal(sorted(T.tolist()), T_test)


2 changes: 1 addition & 1 deletion test/test_mahalanobis_mixin.py
@@ -625,7 +625,7 @@ def test_singular_covariance_init_of_non_strict_pd(estimator, build_dataset):
                       'preprocessing step.')
  with pytest.warns(UserWarning) as raised_warning:
    model.fit(input_data, labels)
-  assert np.any([str(warning.message) == msg for warning in raised_warning])
+  assert any([str(warning.message) == msg for warning in raised_warning])
  M, _ = _initialize_metric_mahalanobis(X, init='covariance',
                                        random_state=RNG,
                                        return_inverse=True,
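
The `np.any` to `any` swap matters more than it looks: with a generator instead of a list, `np.any` does not iterate the elements at all; it wraps the generator in a 0-d object array, which is simply truthy:

```python
# Sketch of the pitfall the builtin avoids.
import numpy as np

print(np.any(x > 10 for x in [1, 2, 3]))  # True -- the generator object
                                          # itself is truthy, never consumed
print(any(x > 10 for x in [1, 2, 3]))     # False -- actually checks elements
```
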
2 changes: 1 addition & 1 deletion test/test_sklearn_compat.py
@@ -224,7 +224,7 @@ def test_cross_validation_manual_vs_scikit(estimator, build_dataset,
  n_splits = 3
  kfold = KFold(shuffle=False, n_splits=n_splits)
  n_samples = input_data.shape[0]
-  fold_sizes = (n_samples // n_splits) * np.ones(n_splits, dtype=np.int)
+  fold_sizes = (n_samples // n_splits) * np.ones(n_splits, dtype=np.int64)
  fold_sizes[:n_samples % n_splits] += 1
  current = 0
  scores, predictions = [], np.zeros(input_data.shape[0])
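
A quick check that the hand-rolled fold-size arithmetic matches scikit-learn's own split sizes (toy numbers):

```python
# Sketch: 10 samples over 3 folds -> sizes [4, 3, 3], same as KFold produces.
import numpy as np
from sklearn.model_selection import KFold

n_samples, n_splits = 10, 3
fold_sizes = (n_samples // n_splits) * np.ones(n_splits, dtype=np.int64)
fold_sizes[:n_samples % n_splits] += 1  # the first n % k folds get one extra

kfold = KFold(n_splits=n_splits, shuffle=False)
sk_sizes = [len(test_idx) for _, test_idx in kfold.split(np.arange(n_samples))]
assert fold_sizes.tolist() == sk_sizes  # [4, 3, 3]
```
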
25 changes: 22 additions & 3 deletions test/test_triplets_classifiers.py
@@ -1,6 +1,7 @@
import pytest
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split
+ import metric_learn

from test.test_utils import triplets_learners, ids_triplets_learners
from metric_learn.sklearn_shims import set_random_state
@@ -20,7 +21,13 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset,
  estimator.set_params(preprocessor=preprocessor)
  set_random_state(estimator)
  triplets_train, triplets_test = train_test_split(input_data)
-  estimator.fit(triplets_train)
+  if isinstance(estimator, metric_learn.SCML):
+    msg = "As no value for `n_basis` was selected, "
+    with pytest.warns(UserWarning) as raised_warning:
+      estimator.fit(triplets_train)
+    assert msg in str(raised_warning[0].message)
+  else:
+    estimator.fit(triplets_train)
  predictions = estimator.predict(triplets_test)

  not_valid = [e for e in predictions if e not in [-1, 1]]
@@ -42,7 +49,13 @@ def test_no_zero_prediction(estimator, build_dataset):
  # Dummy fit
  estimator = clone(estimator)
  set_random_state(estimator)
-  estimator.fit(triplets)
+  if isinstance(estimator, metric_learn.SCML):
+    msg = "As no value for `n_basis` was selected, "
+    with pytest.warns(UserWarning) as raised_warning:
+      estimator.fit(triplets)
+    assert msg in str(raised_warning[0].message)
+  else:
+    estimator.fit(triplets)
  # We force the transformation to be identity, to force euclidean distance
  estimator.components_ = np.eye(X.shape[1])

@@ -93,7 +106,13 @@ def test_accuracy_toy_example(estimator, build_dataset):
  triplets, _, _, X = build_dataset(with_preprocessor=False)
  estimator = clone(estimator)
  set_random_state(estimator)
-  estimator.fit(triplets)
+  if isinstance(estimator, metric_learn.SCML):
+    msg = "As no value for `n_basis` was selected, "
+    with pytest.warns(UserWarning) as raised_warning:
+      estimator.fit(triplets)
+    assert msg in str(raised_warning[0].message)
+  else:
+    estimator.fit(triplets)
  # We take the first two points and build 4 regularly spaced points on the
  # line they define, so that it's easy to build triplets of different
  # similarities.
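
The same SCML guard now appears three times in this file; a hypothetical helper (the name and placement are an assumption, not part of the PR) would collapse the repetition:

```python
# Hypothetical helper, not in the PR: fit while tolerating SCML's
# informational n_basis warning, so each test body stays one line.
import pytest
import metric_learn

def fit_catching_scml_n_basis_warning(estimator, data):
  if isinstance(estimator, metric_learn.SCML):
    with pytest.warns(UserWarning, match="As no value for `n_basis`"):
      estimator.fit(data)
  else:
    estimator.fit(data)
```
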