Skip to content

Commit 394d686

Browse files
committed
unit tests added for safe-level SMOTE
1 parent bcc3069 commit 394d686

File tree

2 files changed

+77
-11
lines changed

2 files changed

+77
-11
lines changed

imblearn/over_sampling/_smote.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,18 +1324,18 @@ class SLSMOTE(BaseSMOTE):
13241324
13251325
Parameters
13261326
-----------
1327-
{sampling_strategy}
1327+
{sampling_strategy}
13281328
1329-
{random_state}
1329+
{random_state}
13301330
1331-
k_neighbors : int or object, optional (default=5)
1331+
k_neighbors : int or object, optional (default=5)
13321332
If ``int``, number of nearest neighbours used to construct synthetic
13331333
samples. If object, an estimator that inherits from
13341334
:class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to
13351335
find the k_neighbors.
13361336
13371337
m_neighbors : int or object, optional (default=10)
1338-
If ``int``, number of nearest neighbours to use to determine the safe
1338+
If ``int``, number of nearest neighbours used to determine the safe
13391339
level of an instance. If object, an estimator that inherits from
13401340
:class:`sklearn.neighbors.base.KNeighborsMixin` that will be used
13411341
to find the m_neighbors.
@@ -1582,16 +1582,16 @@ def _generate_gap(self, a_ratio, rand_state=None):
15821582
15831583
Parameters
15841584
----------
1585-
a_ratio: float
1586-
sl_ratio of a single data point
1585+
a_ratio: float
1586+
sl_ratio of a single data point
15871587
1588-
rand_state: random state object or int
1588+
rand_state: random state object or int
15891589
15901590
1591-
Returns
1592-
------------
1593-
gap: float
1594-
a number between 0 and 1
1591+
Returns
1592+
------------
1593+
gap: float
1594+
a number between 0 and 1
15951595
15961596
"""
15971597

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import pytest
2+
import numpy as np
3+
4+
from sklearn.neighbors import NearestNeighbors
5+
from scipy import sparse
6+
7+
from sklearn.utils._testing import assert_allclose
8+
from sklearn.utils._testing import assert_array_equal
9+
10+
from imblearn.over_sampling import SLSMOTE
11+
12+
13+
def data_np():
    """Build the dense toy dataset shared by the SLSMOTE tests.

    Returns
    -------
    X : ndarray of shape (20, 2)
        Standard-normal samples drawn from ``RandomState(42)``.
    y : ndarray of shape (20,)
        Binary labels: 8 samples of class 0 and 12 samples of class 1.
    """
    labels = np.array(
        [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0]
    )
    # A fixed seed keeps the fixture reproducible across test runs.
    features = np.random.RandomState(42).randn(20, 2)
    return features, labels
18+
19+
20+
def data_sparse(format):
    """Build the sparse toy dataset shared by the SLSMOTE tests.

    Parameters
    ----------
    format : str
        Sparse matrix format passed to ``scipy.sparse.random``
        (the tests use ``'csr'`` and ``'csc'``).

    Returns
    -------
    X : sparse matrix of shape (20, 2)
        Random sparse matrix with density 0.3, seeded with 42.
    y : ndarray of shape (20,)
        Binary labels: 8 samples of class 0 and 12 samples of class 1.
    """
    labels = np.array(
        [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0]
    )
    features = sparse.random(
        20,
        2,
        density=0.3,
        format=format,
        random_state=42,
    )
    return features, labels
24+
25+
26+
@pytest.mark.parametrize(
    "data", [data_np(), data_sparse('csr'), data_sparse('csc')]
)
def test_slsmote(data):
    """Check the resampled shape and labels for dense and sparse input.

    The fixture has 8 class-0 and 12 class-1 samples, so the expected
    output is the original 20 labels followed by 4 new class-0 labels.
    """
    X, y = data
    expected_labels = np.array(
        [
            0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0,
            0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0,
        ]
    )

    sampler = SLSMOTE(random_state=42)
    X_resampled, y_resampled = sampler.fit_resample(X, y)

    assert X_resampled.shape == (24, 2)
    assert_array_equal(y_resampled, expected_labels)
39+
40+
41+
def test_slsmote_nn():
    """Check that explicit NearestNeighbors estimators match the defaults.

    A sampler built with the default integer ``k_neighbors`` and
    ``m_neighbors`` must give the same output as one built with
    ``NearestNeighbors`` objects using 6 and 11 neighbours.
    """
    # NOTE(review): 6 and 11 are presumably default + 1 because each query
    # point is returned as its own neighbour -- confirm against SLSMOTE.
    X, y = data_np()
    with_estimators = SLSMOTE(
        k_neighbors=NearestNeighbors(n_neighbors=6),
        m_neighbors=NearestNeighbors(n_neighbors=11),
        random_state=42,
    )
    with_defaults = SLSMOTE(random_state=42)

    X_default, y_default = with_defaults.fit_resample(X, y)
    X_estimator, y_estimator = with_estimators.fit_resample(X, y)

    assert_allclose(X_default, X_estimator)
    assert_array_equal(y_default, y_estimator)
55+
56+
57+
def test_slsmote_pd():
    """Check that a pandas DataFrame input matches the ndarray result.

    The test is skipped when pandas is not installed.
    """
    pd = pytest.importorskip("pandas")
    X, y = data_np()
    X_pd = pd.DataFrame(X)
    slsmote = SLSMOTE(random_state=42)
    X_res, y_res = slsmote.fit_resample(X, y)
    X_res_pd, y_res_pd = slsmote.fit_resample(X_pd, y)

    # ``DataFrame`` has no ``tolist`` method, so convert through
    # ``np.asarray``; this also works if the sampler returns an ndarray.
    assert_array_equal(np.asarray(X_res_pd), X_res)
    assert_allclose(y_res_pd, y_res)

0 commit comments

Comments
 (0)