Skip to content

Commit 65c88c8

Browse files
author
Joan Massich
committed
add warns context manger that mimics raises
1 parent a60399a commit 65c88c8

File tree

3 files changed

+122
-0
lines changed

3 files changed

+122
-0
lines changed

doc/developers_utils.rst

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,45 @@ On the top of all the functionality provided by scikit-learn. Imbalance-learn
101101
provides :func:`deprecate_parameter`: which is used to deprecate a sampler's
102102
parameter (attribute) by another one.
103103

104+
Testing utilities
105+
=================
106+
Currently, imbalanced-learn provide a warning management utility. This feature
107+
is going to be merge in pytest and will be removed when the pytest release will
108+
have it.
109+
110+
If using Python 2.7 or above, you may use this function as a
111+
context manager::
112+
113+
>>> import warnings
114+
>>> from imblearn.utils.testing import warns
115+
>>> with warns(RuntimeWarning):
116+
... warnings.warn("my runtime warning", RuntimeWarning)
117+
118+
>>> with warns(RuntimeWarning):
119+
... pass
120+
Traceback (most recent call last):
121+
...
122+
Failed: DID NOT WARN. No warnings of type ...RuntimeWarning... was emitted...
123+
124+
>>> with warns(RuntimeWarning):
125+
... warnings.warn(UserWarning)
126+
Traceback (most recent call last):
127+
...
128+
Failed: DID NOT WARN. No warnings of type ...RuntimeWarning... was emitted...
129+
130+
In the context manager form you may use the keyword argument ``match`` to assert
131+
that the exception matches a text or regex::
132+
133+
>>> import warnings
134+
>>> from imblearn.utils.testing import warns
135+
>>> with warns(UserWarning, match='must be 0 or None'):
136+
... warnings.warn("value must be 0 or None", UserWarning)
137+
138+
>>> with warns(UserWarning, match=r'must be \d+$'):
139+
... warnings.warn("value must be 42", UserWarning)
140+
141+
>>> with warns(UserWarning, match=r'must be \d+$'):
142+
... warnings.warn("this is not here", UserWarning)
143+
Traceback (most recent call last):
144+
...
145+
AssertionError: 'must be \d+$' pattern not found in ['this is not here']

imblearn/utils/testing.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414

1515
from sklearn.base import BaseEstimator
1616

17+
from pytest import warns as _warns
18+
from contextlib import contextmanager
19+
from re import compile
20+
21+
1722
# meta-estimators need another estimator to be instantiated.
1823
META_ESTIMATORS = []
1924
# estimators that there is no way to default-construct sensibly
@@ -120,3 +125,53 @@ def is_abstract(c):
120125
# itemgetter is used to ensure the sort does not extend to the 2nd item of
121126
# the tuple
122127
return sorted(set(estimators), key=itemgetter(0))
128+
129+
130+
@contextmanager
131+
def warns(expected_warning, match=None):
132+
"""Assert that a warning is raised with an optional matching pattern
133+
134+
Assert that a code block/function call warns ``expected_warning``
135+
and raise a failure exception otherwise. It can be used within a context
136+
manager ``with``.
137+
138+
Parameters
139+
----------
140+
expected_warning : Warning
141+
Warning type.
142+
143+
match : regex str or None, optional
144+
The pattern to be matched. By default, no check is done.
145+
146+
Returns
147+
-------
148+
None
149+
150+
Examples
151+
--------
152+
153+
>>> import warnings
154+
>>> from imblearn.utils.testing import warns
155+
>>> with warns(UserWarning, match=r'must be \d+$'):
156+
... warnings.warn("value must be 42", UserWarning)
157+
158+
>>> with warns(UserWarning, match=r'must be \d+$'):
159+
... warnings.warn("this is not here", UserWarning)
160+
Traceback (most recent call last):
161+
...
162+
AssertionError: 'must be \d+$' pattern not found in ['this is not here']
163+
...
164+
"""
165+
with _warns(expected_warning) as record:
166+
yield
167+
168+
if match is not None:
169+
for each in record:
170+
if compile(match).search(str(each.message)) is not None:
171+
break
172+
else:
173+
msg = "'{}' pattern not found in {}".format(
174+
match, '{}'.format([str(r.message) for r in record]))
175+
assert False, msg
176+
else:
177+
pass

imblearn/utils/tests/test_testing.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from imblearn.base import SamplerMixin
99
from imblearn.utils.testing import all_estimators
1010

11+
from imblearn.utils.testing import warns
12+
1113

1214
def test_all_estimators():
1315
# check if the filtering is working with a list or a single string
@@ -23,3 +25,26 @@ def test_all_estimators():
2325
type_filter = 'rnd'
2426
with raises(ValueError, match="Parameter type_filter must be 'sampler'"):
2527
all_estimators(type_filter=type_filter)
28+
29+
30+
def test_warns():
31+
import warnings
32+
33+
with warns(UserWarning, match=r'must be \d+$'):
34+
warnings.warn("value must be 42", UserWarning)
35+
36+
with raises(AssertionError, match='pattern not found'):
37+
with warns(UserWarning, match=r'must be \d+$'):
38+
warnings.warn("this is not here", UserWarning)
39+
40+
with warns(UserWarning, match=r'aaa'):
41+
warnings.warn("cccccccccc", UserWarning)
42+
warnings.warn("bbbbbbbbbb", UserWarning)
43+
warnings.warn("aaaaaaaaaa", UserWarning)
44+
45+
a, b, c = ('aaa', 'bbbbbbbbbb', 'cccccccccc')
46+
expected_msg = "'{}' pattern not found in \['{}', '{}'\]".format(a, b, c)
47+
with raises(AssertionError, match=expected_msg):
48+
with warns(UserWarning, match=r'aaa'):
49+
warnings.warn("bbbbbbbbbb", UserWarning)
50+
warnings.warn("cccccccccc", UserWarning)

0 commit comments

Comments
 (0)