Skip to content

Commit 8da92de

Browse files
author
Joan Massich
committed
add warns context manger that mimics raises
1 parent f8ad59e commit 8da92de

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

imblearn/utils/testing.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414

1515
from sklearn.base import BaseEstimator
1616

17+
from pytest import warns as _warns
18+
from contextlib import contextmanager
19+
from re import compile
20+
21+
1722
# meta-estimators need another estimator to be instantiated.
1823
META_ESTIMATORS = []
1924
# estimators that there is no way to default-construct sensibly
@@ -120,3 +125,72 @@ def is_abstract(c):
120125
# itemgetter is used to ensure the sort does not extend to the 2nd item of
121126
# the tuple
122127
return sorted(set(estimators), key=itemgetter(0))
128+
129+
130+
@contextmanager
131+
def warns(expected_warning, message=None, match=None):
132+
"""
133+
Assert that a code block/function call warns ``expected_warning``
134+
and raise a failure exception otherwise.
135+
136+
If using Python 2.7 or above, you may use this function as a
137+
context manager::
138+
139+
>>> import warnings
140+
>>> with warns(RuntimeWarning):
141+
... warnings.warn("my runtime warning", RuntimeWarning)
142+
143+
>>> with warns(RuntimeWarning):
144+
... pass
145+
Traceback (most recent call last):
146+
...
147+
Failed: DID NOT WARN. No warnings of type (<class 'RuntimeError'>,) was emitted. The list of emitted warnings is: [].
148+
149+
>>> with warns(RuntimeWarning):
150+
... warnings.warn(UserWarning)
151+
Traceback (most recent call last):
152+
...
153+
Failed: DID NOT WARN. No warnings of type (<class 'RuntimeError'>,) was emitted. The list of emitted warnings is: [UserWarning(<class 'UserWarning'>,)].
154+
155+
156+
In the context manager form you may use the keyword argument
157+
``message`` to specify a custom failure message::
158+
159+
>>> import warnings
160+
>>> with warns(RuntimeWarning, message="my runtime warning"):
161+
... warnings.warn("different message", RuntimeWarning)
162+
Traceback (most recent call last):
163+
...
164+
AssertionError: "different message" is different from "my runtime warning"
165+
166+
Or you can use the keyword argument ``match`` to assert that the
167+
exception matches a text or regex::
168+
169+
>>> import warnings
170+
>>> with warns(UserWarning, match='must be 0 or None'):
171+
... warnings.warn("value must be 0 or None", UserWarning)
172+
173+
>>> with warns(UserWarning, match=r'must be \d+$'):
174+
... warnings.warn("value must be 42", UserWarning)
175+
176+
>>> with warns(UserWarning, match=r'must be \d+$'):
177+
... warnings.warn("this is not here", UserWarning)
178+
Traceback (most recent call last):
179+
...
180+
AssertionError: "must be \d+$" pattern not found in "this is not here"
181+
182+
"""
183+
with _warns(expected_warning) as record:
184+
yield
185+
186+
wrn_msg = str(record[0].message)
187+
if message is not None:
188+
# message matching has priority over regex match
189+
assert_msg = '"{}" is different from "{}"'.format(wrn_msg, message)
190+
assert wrn_msg == message, assert_msg
191+
elif match is not None:
192+
regex = compile(match)
193+
assert_msg = '"{}" pattern not found in "{}"'.format(match, wrn_msg)
194+
assert regex.search(wrn_msg) is not None, assert_msg
195+
else:
196+
pass

imblearn/utils/tests/test_testing.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from imblearn.base import SamplerMixin
99
from imblearn.utils.testing import all_estimators
1010

11+
from imblearn.utils.testing import warns
12+
from pytest import raises
13+
1114

1215
def test_all_estimators():
1316
# check if the filtering is working with a list or a single string
@@ -23,3 +26,21 @@ def test_all_estimators():
2326
type_filter = 'rnd'
2427
with raises(ValueError, match="Parameter type_filter must be 'sampler'"):
2528
all_estimators(type_filter=type_filter)
29+
30+
31+
def test_warns():
32+
import warnings
33+
34+
with warns(RuntimeWarning, message='my runtime warning'):
35+
warnings.warn("my runtime warning", RuntimeWarning)
36+
37+
with warns(UserWarning, match=r'must be \d+$'):
38+
warnings.warn("value must be 42", UserWarning)
39+
40+
with raises(AssertionError, match='is different from'):
41+
with warns(RuntimeWarning, message="my runtime warning"):
42+
warnings.warn("different message", RuntimeWarning)
43+
44+
with raises(AssertionError, match='pattern not found'):
45+
with warns(UserWarning, match=r'must be \d+$'):
46+
warnings.warn("this is not here", UserWarning)

0 commit comments

Comments
 (0)