Skip to content

Commit 92f058a

Browse files
author
Joan Massich
committed
Move assert_warns to imblearn.utils.test.warns
1 parent 5b91b0e commit 92f058a

File tree

2 files changed

+7
-12
lines changed

2 files changed

+7
-12
lines changed

imblearn/under_sampling/prototype_selection/tests/test_nearmiss.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from __future__ import print_function
77

88
import numpy as np
9-
from sklearn.utils.testing import assert_array_equal, assert_warns
9+
from sklearn.utils.testing import assert_array_equal
10+
from imblearn.utils.testing import warns
1011
from sklearn.neighbors import NearestNeighbors
1112
from pytest import raises
1213

@@ -36,7 +37,8 @@
3637
# FIXME remove at the end of the deprecation 0.4
3738
def test_nearmiss_deprecation():
3839
nm = NearMiss(ver3_samp_ngh=3, version=3)
39-
assert_warns(DeprecationWarning, nm.fit_sample, X, Y)
40+
with warns(DeprecationWarning, match="deprecated from 0.2"):
41+
nm.fit_sample(X, Y)
4042

4143

4244
def test_nearmiss_wrong_version():

imblearn/utils/estimator_checks.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
as sklearn_yield_all_checks, check_estimator \
1919
as sklearn_check_estimator, check_parameters_default_constructible
2020
from sklearn.exceptions import NotFittedError
21-
from sklearn.utils.testing import assert_warns
2221
from pytest import raises
22+
from imblearn.utils.testing import warns
2323

2424
from sklearn.utils.testing import set_random_state
2525

@@ -77,15 +77,8 @@ def check_target_type(name, Estimator):
7777
y = np.linspace(0, 1, 20)
7878
estimator = Estimator()
7979
set_random_state(estimator)
80-
assert_warns(UserWarning, estimator.fit, X, y)
81-
82-
83-
def check_multiclass_warning(name, Estimator):
84-
X = np.random.random((20, 2))
85-
y = np.array([0] * 3 + [1] * 2 + [2] * 15)
86-
estimator = Estimator()
87-
set_random_state(estimator)
88-
assert_warns(UserWarning, estimator.fit, X, y)
80+
with warns(UserWarning, match='should be of types'):
81+
estimator.fit(X, y)
8982

9083

9184
def multioutput_estimator_convert_y_2d(name, y):

0 commit comments

Comments
 (0)