|
14 | 14 |
|
15 | 15 | from sklearn.base import BaseEstimator
|
16 | 16 |
|
| 17 | +from pytest import warns as _warns |
| 18 | +from contextlib import contextmanager |
| 19 | +from re import compile |
| 20 | + |
| 21 | + |
17 | 22 | # meta-estimators need another estimator to be instantiated.
|
18 | 23 | META_ESTIMATORS = []
|
19 | 24 | # estimators that there is no way to default-construct sensibly
|
@@ -120,3 +125,72 @@ def is_abstract(c):
|
120 | 125 | # itemgetter is used to ensure the sort does not extend to the 2nd item of
|
121 | 126 | # the tuple
|
122 | 127 | return sorted(set(estimators), key=itemgetter(0))
|
| 128 | + |
| 129 | + |
| 130 | +@contextmanager |
| 131 | +def warns(expected_warning, message=None, match=None): |
| 132 | + """ |
| 133 | + Assert that a code block/function call warns ``expected_warning`` |
| 134 | + and raise a failure exception otherwise. |
| 135 | +
|
| 136 | + If using Python 2.7 or above, you may use this function as a |
| 137 | + context manager:: |
| 138 | +
|
| 139 | + >>> import warnings |
| 140 | + >>> with warns(RuntimeWarning): |
| 141 | + ... warnings.warn("my runtime warning", RuntimeWarning) |
| 142 | +
|
| 143 | + >>> with warns(RuntimeWarning): |
| 144 | + ... pass |
| 145 | + Traceback (most recent call last): |
| 146 | + ... |
| 147 | + Failed: DID NOT WARN. No warnings of type (<class 'RuntimeError'>,) was emitted. The list of emitted warnings is: []. |
| 148 | +
|
| 149 | + >>> with warns(RuntimeWarning): |
| 150 | + ... warnings.warn(UserWarning) |
| 151 | + Traceback (most recent call last): |
| 152 | + ... |
| 153 | + Failed: DID NOT WARN. No warnings of type (<class 'RuntimeError'>,) was emitted. The list of emitted warnings is: [UserWarning(<class 'UserWarning'>,)]. |
| 154 | +
|
| 155 | +
|
| 156 | + In the context manager form you may use the keyword argument |
| 157 | + ``message`` to specify a custom failure message:: |
| 158 | +
|
| 159 | + >>> import warnings |
| 160 | + >>> with warns(RuntimeWarning, message="my runtime warning"): |
| 161 | + ... warnings.warn("different message", RuntimeWarning) |
| 162 | + Traceback (most recent call last): |
| 163 | + ... |
| 164 | + AssertionError: "different message" is different from "my runtime warning" |
| 165 | +
|
| 166 | + Or you can use the keyword argument ``match`` to assert that the |
| 167 | + exception matches a text or regex:: |
| 168 | +
|
| 169 | + >>> import warnings |
| 170 | + >>> with warns(UserWarning, match='must be 0 or None'): |
| 171 | + ... warnings.warn("value must be 0 or None", UserWarning) |
| 172 | +
|
| 173 | + >>> with warns(UserWarning, match=r'must be \d+$'): |
| 174 | + ... warnings.warn("value must be 42", UserWarning) |
| 175 | +
|
| 176 | + >>> with warns(UserWarning, match=r'must be \d+$'): |
| 177 | + ... warnings.warn("this is not here", UserWarning) |
| 178 | + Traceback (most recent call last): |
| 179 | + ... |
| 180 | + AssertionError: "must be \d+$" pattern not found in "this is not here" |
| 181 | +
|
| 182 | + """ |
| 183 | + with _warns(expected_warning) as record: |
| 184 | + yield |
| 185 | + |
| 186 | + wrn_msg = str(record[0].message) |
| 187 | + if message is not None: |
| 188 | + # message matching has priority over regex match |
| 189 | + assert_msg = '"{}" is different from "{}"'.format(wrn_msg, message) |
| 190 | + assert wrn_msg == message, assert_msg |
| 191 | + elif match is not None: |
| 192 | + regex = compile(match) |
| 193 | + assert_msg = '"{}" pattern not found in "{}"'.format(match, wrn_msg) |
| 194 | + assert regex.search(wrn_msg) is not None, assert_msg |
| 195 | + else: |
| 196 | + pass |
0 commit comments