Skip to content

Commit a60399a

Browse files
author
Joan Massich
committed
Change assert_raise_message to raises with regexp
1 parent 74cefc2 commit a60399a

File tree

2 files changed

+12
-33
lines changed

2 files changed

+12
-33
lines changed

imblearn/metrics/tests/test_classification.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from sklearn.utils.testing import assert_allclose, assert_array_equal
2020
from sklearn.utils.testing import assert_no_warnings
2121
from sklearn.utils.testing import assert_warns_message, ignore_warnings
22-
from sklearn.utils.testing import assert_raise_message
2322
from sklearn.metrics import accuracy_score, average_precision_score
2423
from sklearn.metrics import brier_score_loss, cohen_kappa_score
2524
from sklearn.metrics import jaccard_similarity_score, precision_score
@@ -400,10 +399,8 @@ def test_classification_report_imbalanced_multiclass_with_unicode_label():
400399
u'0.15 0.44 0.19 31 red\xa2 0.42 0.90 0.55 0.57 0.63 '
401400
u'0.37 20 avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75')
402401
if np_version[:3] < (1, 7, 0):
403-
expected_message = ("NumPy < 1.7.0 does not implement"
404-
" searchsorted on unicode data correctly.")
405-
assert_raise_message(RuntimeError, expected_message,
406-
classification_report_imbalanced, y_true, y_pred)
402+
with raises(RuntimeError, match="NumPy < 1.7.0"):
403+
classification_report_imbalanced(y_true, y_pred)
407404
else:
408405
report = classification_report_imbalanced(y_true, y_pred)
409406
assert _format_report(report) == expected_report

imblearn/tests/test_pipeline.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import time
1212

1313
import numpy as np
14-
from sklearn.utils.testing import assert_raise_message
1514
from sklearn.utils.testing import assert_array_equal
1615
from sklearn.utils.testing import assert_array_almost_equal
1716
from sklearn.utils.testing import assert_allclose
@@ -266,11 +265,8 @@ def test_pipeline_fit_params():
266265
assert pipe.named_steps['transf'].a is None
267266
assert pipe.named_steps['transf'].b is None
268267
# invalid parameters should raise an error message
269-
assert_raise_message(
270-
TypeError,
271-
"fit() got an unexpected keyword argument 'bad'",
272-
pipe.fit, None, None, clf__bad=True
273-
)
268+
with raises(TypeError, match="unexpected keyword argument"):
269+
pipe.fit(None, None, clf__bad=True)
274270

275271

276272
def test_pipeline_sample_weight_supported():
@@ -291,32 +287,19 @@ def test_pipeline_sample_weight_unsupported():
291287
pipe.fit(X, y=None)
292288
assert pipe.score(X) == 3
293289
assert pipe.score(X, sample_weight=None) == 3
294-
assert_raise_message(
295-
TypeError,
296-
"score() got an unexpected keyword argument 'sample_weight'",
297-
pipe.score, X, sample_weight=np.array([2, 3])
298-
)
290+
with raises(TypeError, match="unexpected keyword argument"):
291+
pipe.score(X, sample_weight=np.array([2, 3]))
299292

300293

301294
def test_pipeline_raise_set_params_error():
302295
# Test pipeline raises set params error message for nested models.
303296
pipe = Pipeline([('cls', LinearRegression())])
304-
305-
# expected error message
306-
error_msg = ('Invalid parameter %s for estimator %s. '
307-
'Check the list of available parameters '
308-
'with `estimator.get_params().keys()`.')
309-
310-
assert_raise_message(ValueError,
311-
error_msg % ('fake', 'Pipeline'),
312-
pipe.set_params,
313-
fake='nope')
297+
with raises(ValueError, match="Invalid parameter"):
298+
pipe.set_params(fake='nope')
314299

315300
# nested model check
316-
assert_raise_message(ValueError,
317-
error_msg % ("fake", pipe),
318-
pipe.set_params,
319-
fake__estimator='nope')
301+
with raises(ValueError, match="Invalid parameter"):
302+
pipe.set_params(fake__estimator='nope')
320303

321304

322305
def test_pipeline_methods_pca_svm():
@@ -537,9 +520,8 @@ def make():
537520
assert_array_equal([[exp]], pipeline.fit(X, y).transform(X))
538521
assert_array_equal([[exp]], pipeline.fit_transform(X, y))
539522
assert_array_equal(X, pipeline.inverse_transform([[exp]]))
540-
assert_raise_message(AttributeError,
541-
"'NoneType' object has no attribute 'predict'",
542-
getattr, pipeline, 'predict')
523+
with raises(AttributeError, match="has no attribute 'predict'"):
524+
getattr(pipeline, 'predict')
543525

544526
# Check None step at construction time
545527
exp = 2 * 5

0 commit comments

Comments
 (0)