Skip to content

Commit c4c3feb

Browse files
author
Joan Massich
committed
add warns context manger that mimics raises
1 parent 2411d63 commit c4c3feb

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed

doc/developers_utils.rst

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,45 @@ same information as the deprecation warning as explained above. Use the
9696
On the top of all the functionality provided by scikit-learn. Imbalance-learn
9797
provides :func:`deprecate_parameter`: which is used to deprecate a sampler's
9898
parameter (attribute) by another one.
99+
100+
Warning management when testing
101+
===============================
102+
103+
If using Python 2.7 or above, you may use this function as a
104+
context manager::
105+
106+
>>> import warnings
107+
>>> from imblearn.utils.testing import warns
108+
>>> with warns(RuntimeWarning):
109+
... warnings.warn("my runtime warning", RuntimeWarning)
110+
111+
>>> with warns(RuntimeWarning):
112+
... pass
113+
Traceback (most recent call last):
114+
...
115+
Failed: DID NOT WARN. No warnings of type (<class 'RuntimeWarning'>,) was emitted. The list of emitted warnings is: [].
116+
117+
>>> with warns(RuntimeWarning):
118+
... warnings.warn(UserWarning)
119+
Traceback (most recent call last):
120+
...
121+
Failed: DID NOT WARN. No warnings of type (<class 'RuntimeWarning'>,) was emitted. The list of emitted warnings is: [UserWarning(<class 'UserWarning'>,)].
122+
123+
124+
125+
In the context manager form you may use the keyword argument ``match`` to assert
126+
that the exception matches a text or regex::
127+
128+
>>> import warnings
129+
>>> from imblearn.utils.testing import warns
130+
>>> with warns(UserWarning, match='must be 0 or None'):
131+
... warnings.warn("value must be 0 or None", UserWarning)
132+
133+
>>> with warns(UserWarning, match=r'must be \d+$'):
134+
... warnings.warn("value must be 42", UserWarning)
135+
136+
>>> with warns(UserWarning, match=r'must be \d+$'):
137+
... warnings.warn("this is not here", UserWarning)
138+
Traceback (most recent call last):
139+
...
140+
AssertionError: 'must be \d+$' pattern not found in ['this is not here']

imblearn/utils/testing.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414

1515
from sklearn.base import BaseEstimator
1616

17+
from pytest import warns as _warns
18+
from contextlib import contextmanager
19+
from re import compile
20+
21+
1722
# meta-estimators need another estimator to be instantiated.
1823
META_ESTIMATORS = []
1924
# estimators that there is no way to default-construct sensibly
@@ -120,3 +125,25 @@ def is_abstract(c):
120125
# itemgetter is used to ensure the sort does not extend to the 2nd item of
121126
# the tuple
122127
return sorted(set(estimators), key=itemgetter(0))
128+
129+
130+
@contextmanager
131+
def warns(expected_warning, match=None):
132+
"""
133+
Assert that a code block/function call warns ``expected_warning``
134+
and raise a failure exception otherwise.
135+
136+
"""
137+
with _warns(expected_warning) as record:
138+
yield
139+
140+
if match is not None:
141+
for each in record:
142+
if compile(match).search(str(each.message)) is not None:
143+
break
144+
else:
145+
msg = "'{}' pattern not found in {}".format(
146+
match, '{}'.format([str(r.message) for r in record]))
147+
assert False, msg
148+
else:
149+
pass

imblearn/utils/tests/test_testing.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from imblearn.base import SamplerMixin
99
from imblearn.utils.testing import all_estimators
1010

11+
from imblearn.utils.testing import warns
12+
1113

1214
def test_all_estimators():
1315
# check if the filtering is working with a list or a single string
@@ -23,3 +25,26 @@ def test_all_estimators():
2325
type_filter = 'rnd'
2426
with raises(ValueError, match="Parameter type_filter must be 'sampler'"):
2527
all_estimators(type_filter=type_filter)
28+
29+
30+
def test_warns():
31+
import warnings
32+
33+
with warns(UserWarning, match=r'must be \d+$'):
34+
warnings.warn("value must be 42", UserWarning)
35+
36+
with raises(AssertionError, match='pattern not found'):
37+
with warns(UserWarning, match=r'must be \d+$'):
38+
warnings.warn("this is not here", UserWarning)
39+
40+
with warns(UserWarning, match=r'aaa'):
41+
warnings.warn("cccccccccc", UserWarning)
42+
warnings.warn("bbbbbbbbbb", UserWarning)
43+
warnings.warn("aaaaaaaaaa", UserWarning)
44+
45+
a, b, c = ('aaa', 'bbbbbbbbbb', 'cccccccccc')
46+
expected_msg = "'{}' pattern not found in \['{}', '{}'\]".format(a, b, c)
47+
with raises(AssertionError, match=expected_msg):
48+
with warns(UserWarning, match=r'aaa'):
49+
warnings.warn("bbbbbbbbbb", UserWarning)
50+
warnings.warn("cccccccccc", UserWarning)

0 commit comments

Comments
 (0)