Skip to content

Commit 5b91b0e

Browse files
author
Joan Massich
committed
add warns context manger that mimics raises
1 parent a60399a commit 5b91b0e

File tree

3 files changed

+93
-0
lines changed

3 files changed

+93
-0
lines changed

doc/developers_utils.rst

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,44 @@ On the top of all the functionality provided by scikit-learn. Imbalance-learn
101101
provides :func:`deprecate_parameter`: which is used to deprecate a sampler's
102102
parameter (attribute) by another one.
103103

104+
Warning management when testing
105+
===============================
106+
107+
If using Python 2.7 or above, you may use this function as a
108+
context manager::
109+
110+
>>> import warnings
111+
>>> from imblearn.utils.testing import warns
112+
>>> with warns(RuntimeWarning):
113+
... warnings.warn("my runtime warning", RuntimeWarning)
114+
115+
>>> with warns(RuntimeWarning):
116+
... pass
117+
Traceback (most recent call last):
118+
...
119+
Failed: DID NOT WARN. No warnings of type (<.*RuntimeWarning.*>,) was emitted. The list of emitted warnings is: [].
120+
121+
>>> with warns(RuntimeWarning):
122+
... warnings.warn(UserWarning)
123+
Traceback (most recent call last):
124+
...
125+
Failed: DID NOT WARN. No warnings of type (<.*RuntimeWarning.*>,) was emitted. The list of emitted warnings is: [UserWarning(<class 'UserWarning'>,)].
126+
127+
128+
129+
In the context manager form you may use the keyword argument ``match`` to assert
130+
that the exception matches a text or regex::
131+
132+
>>> import warnings
133+
>>> from imblearn.utils.testing import warns
134+
>>> with warns(UserWarning, match='must be 0 or None'):
135+
... warnings.warn("value must be 0 or None", UserWarning)
136+
137+
>>> with warns(UserWarning, match=r'must be \d+$'):
138+
... warnings.warn("value must be 42", UserWarning)
139+
140+
>>> with warns(UserWarning, match=r'must be \d+$'):
141+
... warnings.warn("this is not here", UserWarning)
142+
Traceback (most recent call last):
143+
...
144+
AssertionError: 'must be \d+$' pattern not found in ['this is not here']

imblearn/utils/testing.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414

1515
from sklearn.base import BaseEstimator
1616

17+
from pytest import warns as _warns
18+
from contextlib import contextmanager
19+
from re import compile
20+
21+
1722
# meta-estimators need another estimator to be instantiated.
1823
META_ESTIMATORS = []
1924
# estimators that there is no way to default-construct sensibly
@@ -120,3 +125,25 @@ def is_abstract(c):
120125
# itemgetter is used to ensure the sort does not extend to the 2nd item of
121126
# the tuple
122127
return sorted(set(estimators), key=itemgetter(0))
128+
129+
130+
@contextmanager
131+
def warns(expected_warning, match=None):
132+
"""
133+
Assert that a code block/function call warns ``expected_warning``
134+
and raise a failure exception otherwise.
135+
136+
"""
137+
with _warns(expected_warning) as record:
138+
yield
139+
140+
if match is not None:
141+
for each in record:
142+
if compile(match).search(str(each.message)) is not None:
143+
break
144+
else:
145+
msg = "'{}' pattern not found in {}".format(
146+
match, '{}'.format([str(r.message) for r in record]))
147+
assert False, msg
148+
else:
149+
pass

imblearn/utils/tests/test_testing.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from imblearn.base import SamplerMixin
99
from imblearn.utils.testing import all_estimators
1010

11+
from imblearn.utils.testing import warns
12+
1113

1214
def test_all_estimators():
1315
# check if the filtering is working with a list or a single string
@@ -23,3 +25,26 @@ def test_all_estimators():
2325
type_filter = 'rnd'
2426
with raises(ValueError, match="Parameter type_filter must be 'sampler'"):
2527
all_estimators(type_filter=type_filter)
28+
29+
30+
def test_warns():
31+
import warnings
32+
33+
with warns(UserWarning, match=r'must be \d+$'):
34+
warnings.warn("value must be 42", UserWarning)
35+
36+
with raises(AssertionError, match='pattern not found'):
37+
with warns(UserWarning, match=r'must be \d+$'):
38+
warnings.warn("this is not here", UserWarning)
39+
40+
with warns(UserWarning, match=r'aaa'):
41+
warnings.warn("cccccccccc", UserWarning)
42+
warnings.warn("bbbbbbbbbb", UserWarning)
43+
warnings.warn("aaaaaaaaaa", UserWarning)
44+
45+
a, b, c = ('aaa', 'bbbbbbbbbb', 'cccccccccc')
46+
expected_msg = "'{}' pattern not found in \['{}', '{}'\]".format(a, b, c)
47+
with raises(AssertionError, match=expected_msg):
48+
with warns(UserWarning, match=r'aaa'):
49+
warnings.warn("bbbbbbbbbb", UserWarning)
50+
warnings.warn("cccccccccc", UserWarning)

0 commit comments

Comments
 (0)