Skip to content

Commit bf63efc

Browse files
author
Joan Massich
committed
add warns context manger that mimics raises
1 parent 2411d63 commit bf63efc

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

doc/developers_utils.rst

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,42 @@ same information as the deprecation warning as explained above. Use the
9696
On the top of all the functionality provided by scikit-learn. Imbalance-learn
9797
provides :func:`deprecate_parameter`: which is used to deprecate a sampler's
9898
parameter (attribute) by another one.
99+
100+
Warning management when testing
101+
===============================
102+
103+
If using Python 2.7 or above, you may use this function as a
104+
context manager::
105+
106+
>>> import warnings
107+
>>> with warns(RuntimeWarning):
108+
... warnings.warn("my runtime warning", RuntimeWarning)
109+
110+
>>> with warns(RuntimeWarning):
111+
... pass
112+
Traceback (most recent call last):
113+
...
114+
Failed: DID NOT WARN. No warnings of type (<class 'RuntimeError'>,) was emitted. The list of emitted warnings is: [].
115+
116+
>>> with warns(RuntimeWarning):
117+
... warnings.warn(UserWarning)
118+
Traceback (most recent call last):
119+
...
120+
Failed: DID NOT WARN. No warnings of type (<class 'RuntimeError'>,) was emitted. The list of emitted warnings is: [UserWarning(<class 'UserWarning'>,)].
121+
122+
123+
In the context manager form you may use the keyword argument ``match`` to assert
124+
that the exception matches a text or regex::
125+
126+
>>> import warnings
127+
>>> with warns(UserWarning, match='must be 0 or None'):
128+
... warnings.warn("value must be 0 or None", UserWarning)
129+
130+
>>> with warns(UserWarning, match=r'must be \d+$'):
131+
... warnings.warn("value must be 42", UserWarning)
132+
133+
>>> with warns(UserWarning, match=r'must be \d+$'):
134+
... warnings.warn("this is not here", UserWarning)
135+
Traceback (most recent call last):
136+
...
137+
AssertionError: "must be \d+$" pattern not found in "this is not here"

imblearn/utils/testing.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414

1515
from sklearn.base import BaseEstimator
1616

17+
from pytest import warns as _warns
18+
from contextlib import contextmanager
19+
from re import compile
20+
21+
1722
# meta-estimators need another estimator to be instantiated.
1823
META_ESTIMATORS = []
1924
# estimators that there is no way to default-construct sensibly
@@ -120,3 +125,25 @@ def is_abstract(c):
120125
# itemgetter is used to ensure the sort does not extend to the 2nd item of
121126
# the tuple
122127
return sorted(set(estimators), key=itemgetter(0))
128+
129+
130+
@contextmanager
131+
def warns(expected_warning, match=None):
132+
"""
133+
Assert that a code block/function call warns ``expected_warning``
134+
and raise a failure exception otherwise.
135+
136+
"""
137+
with _warns(expected_warning) as record:
138+
yield
139+
140+
if match is not None:
141+
for each in record:
142+
if compile(match).search(str(each.message)) is not None:
143+
break
144+
else:
145+
msg = "'{}' pattern not found in {}".format(
146+
match, '{}'.format([str(r.message) for r in record]))
147+
assert False, msg
148+
else:
149+
pass

imblearn/utils/tests/test_testing.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from imblearn.base import SamplerMixin
99
from imblearn.utils.testing import all_estimators
1010

11+
from imblearn.utils.testing import warns
12+
1113

1214
def test_all_estimators():
1315
# check if the filtering is working with a list or a single string
@@ -23,3 +25,26 @@ def test_all_estimators():
2325
type_filter = 'rnd'
2426
with raises(ValueError, match="Parameter type_filter must be 'sampler'"):
2527
all_estimators(type_filter=type_filter)
28+
29+
30+
def test_warns():
31+
import warnings
32+
33+
with warns(UserWarning, match=r'must be \d+$'):
34+
warnings.warn("value must be 42", UserWarning)
35+
36+
with raises(AssertionError, match='pattern not found'):
37+
with warns(UserWarning, match=r'must be \d+$'):
38+
warnings.warn("this is not here", UserWarning)
39+
40+
with warns(UserWarning, match=r'aaa'):
41+
warnings.warn("cccccccccc", UserWarning)
42+
warnings.warn("bbbbbbbbbb", UserWarning)
43+
warnings.warn("aaaaaaaaaa", UserWarning)
44+
45+
a, b, c = ('aaa', 'bbbbbbbbbb', 'cccccccccc')
46+
expected_msg = "'{}' pattern not found in \['{}', '{}'\]".format(a, b, c)
47+
with raises(AssertionError, match=expected_msg):
48+
with warns(UserWarning, match=r'aaa'):
49+
warnings.warn("bbbbbbbbbb", UserWarning)
50+
warnings.warn("cccccccccc", UserWarning)

0 commit comments

Comments
 (0)