Skip to content

Commit 6a4aaea

Browse files
authored
Fix 7 sources of warnings in the tests (#339)
* Fix 7 sources of warnings * Fix indentation * Generalized warnings, as old sklearn throw more warnings * Changed np.any() for any() * Fix identation
1 parent aaf8d44 commit 6a4aaea

File tree

9 files changed

+55
-17
lines changed

9 files changed

+55
-17
lines changed

metric_learn/_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None,
704704
elif init == 'covariance':
705705
if input.ndim == 3:
706706
# if the input are tuples, we need to form an X by deduplication
707-
X = np.vstack({tuple(row) for row in input.reshape(-1, n_features)})
707+
X = np.unique(np.vstack(input), axis=0)
708708
else:
709709
X = input
710710
# atleast2d is necessary to deal with scalar covariance matrices

metric_learn/itml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _fit(self, pairs, y, bounds=None):
3232
type_of_inputs='tuples')
3333
# init bounds
3434
if bounds is None:
35-
X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
35+
X = np.unique(np.vstack(pairs), axis=0)
3636
self.bounds_ = np.percentile(pairwise_distances(X), (5, 95))
3737
else:
3838
bounds = check_array(bounds, allow_nd=False, ensure_min_samples=0,

metric_learn/rca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def fit(self, X, chunks):
112112
# Fisher Linear Discriminant projection
113113
if dim < X.shape[1]:
114114
total_cov = np.cov(X[chunk_mask], rowvar=0)
115-
tmp = np.linalg.lstsq(total_cov, inner_cov)[0]
115+
tmp = np.linalg.lstsq(total_cov, inner_cov, rcond=None)[0]
116116
vals, vecs = np.linalg.eig(tmp)
117117
inds = np.argsort(vals)[:dim]
118118
A = vecs[:, inds]

metric_learn/scml.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,10 @@ def _generate_bases_LDA(self, X, y):
615615
k_class = np.vstack((np.minimum(class_count, scales[0]),
616616
np.minimum(class_count, scales[1])))
617617

618-
idx_set = [np.zeros((n_clusters, sum(k_class[0, :])), dtype=np.int),
619-
np.zeros((n_clusters, sum(k_class[1, :])), dtype=np.int)]
618+
idx_set = [np.zeros((n_clusters, sum(k_class[0, :])), dtype=np.int64),
619+
np.zeros((n_clusters, sum(k_class[1, :])), dtype=np.int64)]
620620

621-
start_finish_indices = np.hstack((np.zeros((2, 1), np.int),
621+
start_finish_indices = np.hstack((np.zeros((2, 1), np.int64),
622622
k_class)).cumsum(axis=1)
623623

624624
neigh = NearestNeighbors()

pytest.ini

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[pytest]
2+
markers =
3+
integration: mark a test as integration
4+
unit: mark a test as unit

test/metric_learn_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
make_spd_matrix)
1010
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
1111
assert_allclose)
12-
from metric_learn.sklearn_shims import assert_warns_message
1312
from sklearn.exceptions import ConvergenceWarning
1413
from sklearn.utils.validation import check_X_y
1514
from sklearn.preprocessing import StandardScaler
@@ -1143,9 +1142,10 @@ def test_convergence_warning(dataset, algo_class):
11431142
X, y = dataset
11441143
model = algo_class(max_iter=2, verbose=True)
11451144
cls_name = model.__class__.__name__
1146-
assert_warns_message(ConvergenceWarning,
1147-
'[{}] {} did not converge'.format(cls_name, cls_name),
1148-
model.fit, X, y)
1145+
msg = '[{}] {} did not converge'.format(cls_name, cls_name)
1146+
with pytest.warns(Warning) as raised_warning:
1147+
model.fit(X, y)
1148+
assert any([msg in str(warn.message) for warn in raised_warning])
11491149

11501150

11511151
if __name__ == '__main__':

test/test_constraints.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_generate_knntriplets_under_edge(k_genuine, k_impostor, T_test):
103103

104104

105105
@pytest.mark.parametrize("k_genuine, k_impostor,",
106-
[(2, 3), (3, 3), (2, 4), (3, 4)])
106+
[(3, 3), (2, 4), (3, 4), (10, 9), (144, 33)])
107107
def test_generate_knntriplets(k_genuine, k_impostor):
108108
"""Checks edge and over the edge cases of knn triplet construction with not
109109
enough neighbors"""
@@ -118,8 +118,23 @@ def test_generate_knntriplets(k_genuine, k_impostor):
118118
X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [33, 33]])
119119
y = np.array([1, 1, 1, 2, 2, 2, -1])
120120

121-
T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor)
122-
121+
msg1 = ("The class 1 has 3 elements, which is not sufficient to "
122+
f"generate {k_genuine+1} genuine neighbors "
123+
"as specified by k_genuine")
124+
msg2 = ("The class 2 has 3 elements, which is not sufficient to "
125+
f"generate {k_genuine+1} genuine neighbors "
126+
"as specified by k_genuine")
127+
msg3 = ("The class 1 has 3 elements of other classes, which is "
128+
f"not sufficient to generate {k_impostor} impostor "
129+
"neighbors as specified by k_impostor")
130+
msg4 = ("The class 2 has 3 elements of other classes, which is "
131+
f"not sufficient to generate {k_impostor} impostor "
132+
"neighbors as specified by k_impostor")
133+
msgs = [msg1, msg2, msg3, msg4]
134+
with pytest.warns(UserWarning) as user_warning:
135+
T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor)
136+
assert any([[msg in str(warn.message) for msg in msgs]
137+
for warn in user_warning])
123138
assert np.array_equal(sorted(T.tolist()), T_test)
124139

125140

test/test_sklearn_compat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def test_cross_validation_manual_vs_scikit(estimator, build_dataset,
235235
n_splits = 3
236236
kfold = KFold(shuffle=False, n_splits=n_splits)
237237
n_samples = input_data.shape[0]
238-
fold_sizes = (n_samples // n_splits) * np.ones(n_splits, dtype=np.int)
238+
fold_sizes = (n_samples // n_splits) * np.ones(n_splits, dtype=np.int64)
239239
fold_sizes[:n_samples % n_splits] += 1
240240
current = 0
241241
scores, predictions = [], np.zeros(input_data.shape[0])

test/test_triplets_classifiers.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from sklearn.exceptions import NotFittedError
33
from sklearn.model_selection import train_test_split
4+
import metric_learn
45

56
from test.test_utils import triplets_learners, ids_triplets_learners
67
from metric_learn.sklearn_shims import set_random_state
@@ -20,7 +21,13 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset,
2021
estimator.set_params(preprocessor=preprocessor)
2122
set_random_state(estimator)
2223
triplets_train, triplets_test = train_test_split(input_data)
23-
estimator.fit(triplets_train)
24+
if isinstance(estimator, metric_learn.SCML):
25+
msg = "As no value for `n_basis` was selected, "
26+
with pytest.warns(UserWarning) as raised_warning:
27+
estimator.fit(triplets_train)
28+
assert msg in str(raised_warning[0].message)
29+
else:
30+
estimator.fit(triplets_train)
2431
predictions = estimator.predict(triplets_test)
2532

2633
not_valid = [e for e in predictions if e not in [-1, 1]]
@@ -42,7 +49,13 @@ def test_no_zero_prediction(estimator, build_dataset):
4249
# Dummy fit
4350
estimator = clone(estimator)
4451
set_random_state(estimator)
45-
estimator.fit(triplets)
52+
if isinstance(estimator, metric_learn.SCML):
53+
msg = "As no value for `n_basis` was selected, "
54+
with pytest.warns(UserWarning) as raised_warning:
55+
estimator.fit(triplets)
56+
assert msg in str(raised_warning[0].message)
57+
else:
58+
estimator.fit(triplets)
4659
# We force the transformation to be identity, to force euclidean distance
4760
estimator.components_ = np.eye(X.shape[1])
4861

@@ -93,7 +106,13 @@ def test_accuracy_toy_example(estimator, build_dataset):
93106
triplets, _, _, X = build_dataset(with_preprocessor=False)
94107
estimator = clone(estimator)
95108
set_random_state(estimator)
96-
estimator.fit(triplets)
109+
if isinstance(estimator, metric_learn.SCML):
110+
msg = "As no value for `n_basis` was selected, "
111+
with pytest.warns(UserWarning) as raised_warning:
112+
estimator.fit(triplets)
113+
assert msg in str(raised_warning[0].message)
114+
else:
115+
estimator.fit(triplets)
97116
# We take the two first points and we build 4 regularly spaced points on the
98117
# line they define, so that it's easy to build triplets of different
99118
# similarities.

0 commit comments

Comments
 (0)