Skip to content

Commit 488a0e8

Browse files
massichglemaitre
authored andcommitted
[MRG] Migrate raising errors from nose to pytest (#321)
* change assert_raise for raises(xxx) * Migrate assert_raises_regex to pytest.raises * Change assert_raise_message to raises with regexp * add warns context manger that mimics raises * Move assert_warns to imblearn.utils.test.warns * migrate assert_warns_message to imblearn.utils.testing.warns * Move import statements for codebase coherence.
1 parent a2de53a commit 488a0e8

26 files changed

+400
-266
lines changed

doc/developers_utils.rst

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,45 @@ On the top of all the functionality provided by scikit-learn. Imbalance-learn
101101
provides :func:`deprecate_parameter`: which is used to deprecate a sampler's
102102
parameter (attribute) by another one.
103103

104+
Testing utilities
105+
=================
106+
Currently, imbalanced-learn provide a warning management utility. This feature
107+
is going to be merge in pytest and will be removed when the pytest release will
108+
have it.
109+
110+
If using Python 2.7 or above, you may use this function as a
111+
context manager::
112+
113+
>>> import warnings
114+
>>> from imblearn.utils.testing import warns
115+
>>> with warns(RuntimeWarning):
116+
... warnings.warn("my runtime warning", RuntimeWarning)
117+
118+
>>> with warns(RuntimeWarning):
119+
... pass
120+
Traceback (most recent call last):
121+
...
122+
Failed: DID NOT WARN. No warnings of type ...RuntimeWarning... was emitted...
123+
124+
>>> with warns(RuntimeWarning):
125+
... warnings.warn(UserWarning)
126+
Traceback (most recent call last):
127+
...
128+
Failed: DID NOT WARN. No warnings of type ...RuntimeWarning... was emitted...
129+
130+
In the context manager form you may use the keyword argument ``match`` to assert
131+
that the exception matches a text or regex::
132+
133+
>>> import warnings
134+
>>> from imblearn.utils.testing import warns
135+
>>> with warns(UserWarning, match='must be 0 or None'):
136+
... warnings.warn("value must be 0 or None", UserWarning)
137+
138+
>>> with warns(UserWarning, match=r'must be \d+$'):
139+
... warnings.warn("value must be 42", UserWarning)
140+
141+
>>> with warns(UserWarning, match=r'must be \d+$'):
142+
... warnings.warn("this is not here", UserWarning)
143+
Traceback (most recent call last):
144+
...
145+
AssertionError: 'must be \d+$' pattern not found in ['this is not here']

imblearn/combine/tests/test_smote_enn.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from __future__ import print_function
77

88
import numpy as np
9+
from pytest import raises
10+
911
from sklearn.utils.testing import assert_allclose, assert_array_equal
10-
from sklearn.utils.testing import assert_raises_regex
1112

1213
from imblearn.combine import SMOTEENN
1314
from imblearn.under_sampling import EditedNearestNeighbours
@@ -113,8 +114,8 @@ def test_error_wrong_object():
113114
smote = 'rnd'
114115
enn = 'rnd'
115116
smt = SMOTEENN(smote=smote, random_state=RND_SEED)
116-
assert_raises_regex(ValueError, "smote needs to be a SMOTE",
117-
smt.fit_sample, X, Y)
117+
with raises(ValueError, match="smote needs to be a SMOTE"):
118+
smt.fit_sample(X, Y)
118119
smt = SMOTEENN(enn=enn, random_state=RND_SEED)
119-
assert_raises_regex(ValueError, "enn needs to be an ",
120-
smt.fit_sample, X, Y)
120+
with raises(ValueError, match="enn needs to be an "):
121+
smt.fit_sample(X, Y)

imblearn/combine/tests/test_smote_tomek.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from __future__ import print_function
77

88
import numpy as np
9+
from pytest import raises
10+
911
from sklearn.utils.testing import assert_allclose, assert_array_equal
10-
from sklearn.utils.testing import assert_raises_regex
1112

1213
from imblearn.combine import SMOTETomek
1314
from imblearn.over_sampling import SMOTE
@@ -156,8 +157,8 @@ def test_error_wrong_object():
156157
smote = 'rnd'
157158
tomek = 'rnd'
158159
smt = SMOTETomek(smote=smote, random_state=RND_SEED)
159-
assert_raises_regex(ValueError, "smote needs to be a SMOTE",
160-
smt.fit_sample, X, Y)
160+
with raises(ValueError, match="smote needs to be a SMOTE"):
161+
smt.fit_sample(X, Y)
161162
smt = SMOTETomek(tomek=tomek, random_state=RND_SEED)
162-
assert_raises_regex(ValueError, "tomek needs to be a TomekLinks",
163-
smt.fit_sample, X, Y)
163+
with raises(ValueError, match="tomek needs to be a TomekLinks"):
164+
smt.fit_sample(X, Y)

imblearn/datasets/tests/test_imbalance.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010

1111
import numpy as np
1212

13+
from pytest import raises
14+
1315
from sklearn.datasets import load_iris
14-
from sklearn.utils.testing import assert_raises_regex
15-
from sklearn.utils.testing import assert_warns_message
1616

17+
from imblearn.utils.testing import warns
1718
from imblearn.datasets import make_imbalance
1819

1920
data = load_iris()
@@ -24,28 +25,28 @@ def test_make_imbalance_error():
2425
# we are reusing part of utils.check_ratio, however this is not cover in
2526
# the common tests so we will repeat it here
2627
ratio = {0: -100, 1: 50, 2: 50}
27-
assert_raises_regex(ValueError, "in a class cannot be negative",
28-
make_imbalance, X, Y, ratio)
28+
with raises(ValueError, match="in a class cannot be negative"):
29+
make_imbalance(X, Y, ratio)
2930
ratio = {0: 10, 1: 70}
30-
assert_raises_regex(ValueError, "should be less or equal to the original",
31-
make_imbalance, X, Y, ratio)
31+
with raises(ValueError, match="should be less or equal to the original"):
32+
make_imbalance(X, Y, ratio)
3233
y_ = np.zeros((X.shape[0], ))
3334
ratio = {0: 10}
34-
assert_raises_regex(ValueError, "needs to have more than 1 class.",
35-
make_imbalance, X, y_, ratio)
35+
with raises(ValueError, match="needs to have more than 1 class."):
36+
make_imbalance(X, y_, ratio)
3637
ratio = 'random-string'
37-
assert_raises_regex(ValueError, "has to be a dictionary or a function",
38-
make_imbalance, X, Y, ratio)
38+
with raises(ValueError, match="has to be a dictionary or a function"):
39+
make_imbalance(X, Y, ratio)
3940

4041

4142
# FIXME: to be removed in 0.4 due to deprecation
4243
def test_make_imbalance_float():
43-
X_, y_ = assert_warns_message(DeprecationWarning,
44-
"'min_c_' is deprecated in 0.2",
45-
make_imbalance, X, Y, ratio=0.5, min_c_=1)
46-
X_, y_ = assert_warns_message(DeprecationWarning,
47-
"'ratio' being a float is deprecated",
48-
make_imbalance, X, Y, ratio=0.5, min_c_=1)
44+
with warns(DeprecationWarning, match="deprecated in 0.2"):
45+
X_, y_ = make_imbalance(X, Y, ratio=0.5, min_c_=1)
46+
47+
with warns(DeprecationWarning, match="'ratio' being a float"):
48+
X_, y_ = make_imbalance(X, Y, ratio=0.5, min_c_=1)
49+
4950
assert Counter(y_) == {0: 50, 1: 25, 2: 50}
5051
# resample without using min_c_
5152
X_, y_ = make_imbalance(X_, y_, ratio=0.25, min_c_=None)

imblearn/datasets/tests/test_zenodo.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
from imblearn.datasets import fetch_datasets
1010
from sklearn.utils.testing import SkipTest, assert_allclose
11-
from sklearn.utils.testing import assert_raises_regex
11+
12+
from pytest import raises
1213

1314
DATASET_SHAPE = {'ecoli': (336, 7),
1415
'optical_digits': (5620, 64),
@@ -84,11 +85,11 @@ def test_fetch_filter():
8485

8586

8687
def test_fetch_error():
87-
assert_raises_regex(ValueError, 'is not a dataset available.',
88-
fetch_datasets, filter_data=tuple(['rnd']))
89-
assert_raises_regex(ValueError, 'dataset with the ID=',
90-
fetch_datasets, filter_data=tuple([-1]))
91-
assert_raises_regex(ValueError, 'dataset with the ID=',
92-
fetch_datasets, filter_data=tuple([100]))
93-
assert_raises_regex(ValueError, 'value in the tuple',
94-
fetch_datasets, filter_data=tuple([1.00]))
88+
with raises(ValueError, match='is not a dataset available.'):
89+
fetch_datasets(filter_data=tuple(['rnd']))
90+
with raises(ValueError, match='dataset with the ID='):
91+
fetch_datasets(filter_data=tuple([-1]))
92+
with raises(ValueError, match='dataset with the ID='):
93+
fetch_datasets(filter_data=tuple([100]))
94+
with raises(ValueError, match='value in the tuple'):
95+
fetch_datasets(filter_data=tuple([1.00]))

imblearn/ensemble/tests/test_balance_cascade.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@
66
from __future__ import print_function
77

88
import numpy as np
9-
from sklearn.utils.testing import assert_array_equal, assert_raises
10-
from sklearn.utils.testing import assert_raises_regex
9+
10+
from pytest import raises
11+
12+
from sklearn.utils.testing import assert_array_equal
1113
from sklearn.ensemble import RandomForestClassifier
1214

1315
from imblearn.ensemble import BalanceCascade
1416

17+
1518
RND_SEED = 0
1619
X = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141],
1720
[1.25192108, -0.22367336], [0.53366841, -0.30312976],
@@ -299,7 +302,8 @@ def test_fit_sample_auto_linear_svm():
299302
def test_init_wrong_classifier():
300303
classifier = 'rnd'
301304
bc = BalanceCascade(classifier=classifier)
302-
assert_raises(NotImplementedError, bc.fit_sample, X, Y)
305+
with raises(NotImplementedError):
306+
bc.fit_sample(X, Y)
303307

304308

305309
def test_fit_sample_auto_early_stop():
@@ -362,5 +366,5 @@ def test_give_classifier_wrong_obj():
362366
classifier = 2
363367
bc = BalanceCascade(ratio=ratio, random_state=RND_SEED,
364368
return_indices=True, estimator=classifier)
365-
assert_raises_regex(ValueError, "Invalid parameter `estimator`",
366-
bc.fit_sample, X, Y)
369+
with raises(ValueError, match="Invalid parameter `estimator`"):
370+
bc.fit_sample(X, Y)

imblearn/metrics/tests/test_classification.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,17 @@
1010

1111
import numpy as np
1212

13+
from pytest import approx, raises
14+
1315
from sklearn import datasets
1416
from sklearn import svm
1517

1618
from sklearn.preprocessing import label_binarize
1719
from sklearn.utils.fixes import np_version
1820
from sklearn.utils.validation import check_random_state
1921
from sklearn.utils.testing import assert_allclose, assert_array_equal
20-
from sklearn.utils.testing import assert_no_warnings, assert_raises
21-
from sklearn.utils.testing import assert_warns_message, ignore_warnings
22-
from sklearn.utils.testing import assert_raise_message
22+
from sklearn.utils.testing import assert_no_warnings
23+
from sklearn.utils.testing import ignore_warnings
2324
from sklearn.metrics import accuracy_score, average_precision_score
2425
from sklearn.metrics import brier_score_loss, cohen_kappa_score
2526
from sklearn.metrics import jaccard_similarity_score, precision_score
@@ -32,7 +33,8 @@
3233
from imblearn.metrics import make_index_balanced_accuracy
3334
from imblearn.metrics import classification_report_imbalanced
3435

35-
from pytest import approx
36+
from imblearn.utils.testing import warns
37+
3638

3739
RND_SEED = 42
3840
R_TOL = 1e-2
@@ -177,40 +179,30 @@ def test_sensitivity_specificity_error_multilabels():
177179
y_true_bin = label_binarize(y_true, classes=np.arange(5))
178180
y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
179181

180-
assert_raises(ValueError, sensitivity_score, y_true_bin, y_pred_bin)
182+
with raises(ValueError):
183+
sensitivity_score(y_true_bin, y_pred_bin)
181184

182185

183186
@ignore_warnings
184187
def test_sensitivity_specificity_support_errors():
185188
y_true, y_pred, _ = make_prediction(binary=True)
186189

187190
# Bad pos_label
188-
assert_raises(
189-
ValueError,
190-
sensitivity_specificity_support,
191-
y_true,
192-
y_pred,
193-
pos_label=2,
194-
average='binary')
191+
with raises(ValueError):
192+
sensitivity_specificity_support(y_true, y_pred, pos_label=2,
193+
average='binary')
195194

196195
# Bad average option
197-
assert_raises(
198-
ValueError,
199-
sensitivity_specificity_support, [0, 1, 2], [1, 2, 0],
200-
average='mega')
196+
with raises(ValueError):
197+
sensitivity_specificity_support([0, 1, 2], [1, 2, 0], average='mega')
201198

202199

203200
def test_sensitivity_specificity_unused_pos_label():
204201
# but average != 'binary'; even if data is binary
205-
assert_warns_message(
206-
UserWarning,
207-
"Note that pos_label (set to 2) is "
208-
"ignored when average != 'binary' (got 'macro'). You "
209-
"may use labels=[pos_label] to specify a single "
210-
"positive class.",
211-
sensitivity_specificity_support, [1, 2, 1], [1, 2, 2],
212-
pos_label=2,
213-
average='macro')
202+
with warns(UserWarning, "use labels=\[pos_label\] to specify a single"):
203+
sensitivity_specificity_support([1, 2, 1], [1, 2, 2],
204+
pos_label=2,
205+
average='macro')
214206

215207

216208
def test_geometric_mean_support_binary():
@@ -405,10 +397,8 @@ def test_classification_report_imbalanced_multiclass_with_unicode_label():
405397
u'0.15 0.44 0.19 31 red\xa2 0.42 0.90 0.55 0.57 0.63 '
406398
u'0.37 20 avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75')
407399
if np_version[:3] < (1, 7, 0):
408-
expected_message = ("NumPy < 1.7.0 does not implement"
409-
" searchsorted on unicode data correctly.")
410-
assert_raise_message(RuntimeError, expected_message,
411-
classification_report_imbalanced, y_true, y_pred)
400+
with raises(RuntimeError, match="NumPy < 1.7.0"):
401+
classification_report_imbalanced(y_true, y_pred)
412402
else:
413403
report = classification_report_imbalanced(y_true, y_pred)
414404
assert _format_report(report) == expected_report
@@ -459,16 +449,20 @@ def test_iba_error_y_score_prob():
459449

460450
aps = make_index_balanced_accuracy(alpha=0.5, squared=True)(
461451
average_precision_score)
462-
assert_raises(AttributeError, aps, y_true, y_pred)
452+
with raises(AttributeError):
453+
aps(y_true, y_pred)
463454

464455
brier = make_index_balanced_accuracy(alpha=0.5, squared=True)(
465456
brier_score_loss)
466-
assert_raises(AttributeError, brier, y_true, y_pred)
457+
with raises(AttributeError):
458+
brier(y_true, y_pred)
467459

468460
kappa = make_index_balanced_accuracy(alpha=0.5, squared=True)(
469461
cohen_kappa_score)
470-
assert_raises(AttributeError, kappa, y_true, y_pred)
462+
with raises(AttributeError):
463+
kappa(y_true, y_pred)
471464

472465
ras = make_index_balanced_accuracy(alpha=0.5, squared=True)(
473466
roc_auc_score)
474-
assert_raises(AttributeError, ras, y_true, y_pred)
467+
with raises(AttributeError):
468+
ras(y_true, y_pred)

imblearn/over_sampling/tests/test_adasyn.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
from __future__ import print_function
77

88
import numpy as np
9+
from pytest import raises
10+
911
from sklearn.utils.testing import assert_allclose, assert_array_equal
10-
from sklearn.utils.testing import assert_raises_regex
1112
from sklearn.neighbors import NearestNeighbors
1213

1314
from imblearn.over_sampling import ADASYN
1415

16+
1517
RND_SEED = 0
1618
X = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141],
1719
[1.25192108, -0.22367336], [0.53366841, -0.30312976],
@@ -141,5 +143,5 @@ def test_ada_fit_sample_nn_obj():
141143
def test_ada_wrong_nn_obj():
142144
nn = 'rnd'
143145
ada = ADASYN(random_state=RND_SEED, n_neighbors=nn)
144-
assert_raises_regex(ValueError, "has to be one of",
145-
ada.fit_sample, X, Y)
146+
with raises(ValueError, match="has to be one of"):
147+
ada.fit_sample(X, Y)

0 commit comments

Comments
 (0)