Skip to content

Commit 609c4fc

Browse files
committed
fixed variable name, added doc and test
1 parent 394d686 commit 609c4fc

File tree

4 files changed

+98
-49
lines changed

4 files changed

+98
-49
lines changed

doc/over_sampling.rst

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,9 @@ nearest neighbors class. Those variants are presented in the figure below.
152152
:align: center
153153

154154

155-
The :class:`BorderlineSMOTE` [HWB2005]_, :class:`SVMSMOTE` [NCK2009]_, and
156-
:class:`KMeansSMOTE` [LDB2017]_ offer some variant of the SMOTE algorithm::
155+
The :class:`BorderlineSMOTE` [HWB2005]_, :class:`SVMSMOTE` [NCK2009]_,
156+
:class:`KMeansSMOTE` [LDB2017]_ and :class:`SafeLevelSMOTE` [BSL2009]_
157+
offer some variant of the SMOTE algorithm::
157158

158159
>>> from imblearn.over_sampling import BorderlineSMOTE
159160
>>> X_resampled, y_resampled = BorderlineSMOTE().fit_resample(X, y)
@@ -213,6 +214,14 @@ other extra interpolation.
213214
Imbalanced Learning Based on K-Means and SMOTE"
214215
https://arxiv.org/abs/1711.00837
215216
217+
[BSL2009] C. Bunkhumpornpat, K. Sinapiromsaran, C. Lursinsap,
218+
"Safe-level-SMOTE: Safe-level-synthetic minority over-sampling
219+
technique for handling the class imbalanced problem," In:
220+
Theeramunkong T., Kijsirikul B., Cercone N., Ho TB. (eds)
221+
Advances in Knowledge Discovery and Data Mining. PAKDD 2009.
222+
Lecture Notes in Computer Science, vol 5476. Springer, Berlin,
223+
Heidelberg, 475-482, 2009.
224+
216225
Mathematical formulation
217226
========================
218227

@@ -274,6 +283,11 @@ parameter ``m_neighbors`` to decide if a sample is in danger, safe, or noise.
274283
method before to apply SMOTE. The clustering will group samples together and
275284
generate new samples depending of the cluster density.
276285

286+
**SafeLevel** SMOTE --- cf. to :class:`SafeLevelSMOTE` --- uses the safe level
287+
(the number of positive instances in nearest neighbors) to generate a synthetic
288+
instance. Compared to regular SMOTE, the new instance is positioned closer to
289+
the positive instance with larger safe level.
290+
277291
ADASYN works similarly to the regular SMOTE. However, the number of
278292
samples generated for each :math:`x_i` is proportional to the number of samples
279293
which are not from the same class than :math:`x_i` in a given

imblearn/over_sampling/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ._smote import KMeansSMOTE
1111
from ._smote import SVMSMOTE
1212
from ._smote import SMOTENC
13-
from ._smote import SLSMOTE
13+
from ._smote import SafeLevelSMOTE
1414

1515
__all__ = [
1616
"ADASYN",
@@ -20,5 +20,5 @@
2020
"BorderlineSMOTE",
2121
"SVMSMOTE",
2222
"SMOTENC",
23-
"SLSMOTE",
23+
"SafeLevelSMOTE",
2424
]

imblearn/over_sampling/_smote.py

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,11 @@ class BorderlineSMOTE(BaseSMOTE):
284284
285285
SVMSMOTE : Over-sample using SVM-SMOTE variant.
286286
287+
KMeansSMOTE: Over-sample using KMeans-SMOTE variant.
288+
289+
SafeLevelSMOTE: Over-sample using SafeLevel-SMOTE variant.
290+
291+
287292
ADASYN : Over-sample using ADASYN.
288293
289294
References
@@ -484,6 +489,10 @@ class SVMSMOTE(BaseSMOTE):
484489
485490
BorderlineSMOTE : Over-sample using Borderline-SMOTE.
486491
492+
KMeansSMOTE: Over-sample using KMeans-SMOTE variant.
493+
494+
SafeLevelSMOTE: Over-sample using SafeLevel-SMOTE variant.
495+
487496
ADASYN : Over-sample using ADASYN.
488497
489498
References
@@ -695,6 +704,10 @@ class SMOTE(BaseSMOTE):
695704
696705
SVMSMOTE : Over-sample using the SVM-SMOTE variant.
697706
707+
KMeansSMOTE: Over-sample using KMeans-SMOTE variant.
708+
709+
SafeLevelSMOTE: Over-sample using SafeLevel-SMOTE variant.
710+
698711
ADASYN : Over-sample using ADASYN.
699712
700713
References
@@ -864,6 +877,10 @@ class SMOTENC(SMOTE):
864877
865878
BorderlineSMOTE : Over-sample using Borderline-SMOTE variant.
866879
880+
KMeansSMOTE: Over-sample using KMeans-SMOTE variant.
881+
882+
SafeLevelSMOTE: Over-sample using SafeLevel-SMOTE variant.
883+
867884
ADASYN : Over-sample using ADASYN.
868885
869886
References
@@ -1318,7 +1335,7 @@ def _fit_resample(self, X, y):
13181335
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
13191336
random_state=_random_state_docstring,
13201337
)
1321-
class SLSMOTE(BaseSMOTE):
1338+
class SafeLevelSMOTE(BaseSMOTE):
13221339
"""Class to perform over-sampling using safe-level SMOTE.
13231340
This is an implementation of the Safe-level-SMOTE described in [2]_.
13241341
@@ -1389,13 +1406,13 @@ class SLSMOTE(BaseSMOTE):
13891406
>>> from collections import Counter
13901407
>>> from sklearn.datasets import make_classification
13911408
>>> from imblearn.over_sampling import \
1392-
SLSMOTE # doctest: +NORMALIZE_WHITESPACE
1409+
SafeLevelSMOTE # doctest: +NORMALIZE_WHITESPACE
13931410
>>> X, y = make_classification(n_classes=2, class_sep=2,
13941411
... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
13951412
... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
13961413
>>> print('Original dataset shape %s' % Counter(y))
13971414
Original dataset shape Counter({{1: 900, 0: 100}})
1398-
>>> sm = SLSMOTE(random_state=42)
1415+
>>> sm = SafeLevelSMOTE(random_state=42)
13991416
>>> X_res, y_res = sm.fit_resample(X, y)
14001417
>>> print('Resampled dataset shape %s' % Counter(y_res))
14011418
Resampled dataset shape Counter({{0: 900, 1: 900}})
@@ -1415,7 +1432,7 @@ def __init__(self,
14151432

14161433
self.m_neighbors = m_neighbors
14171434

1418-
def _assign_sl(self, nn_estimator, samples, target_class, y):
1435+
def _assign_safe_levels(self, nn_estimator, samples, target_class, y):
14191436
'''
14201437
Assign the safe levels to the instances in the target class.
14211438
@@ -1444,8 +1461,8 @@ def _assign_sl(self, nn_estimator, samples, target_class, y):
14441461

14451462
x = nn_estimator.kneighbors(samples, return_distance=False)[:, 1:]
14461463
nn_label = (y[x] == target_class).astype(int)
1447-
sl = np.sum(nn_label, axis=1)
1448-
return sl
1464+
safe_levels = np.sum(nn_label, axis=1)
1465+
return safe_levels
14491466

14501467
def _validate_estimator(self):
14511468
super()._validate_estimator()
@@ -1466,28 +1483,30 @@ def _fit_resample(self, X, y):
14661483
X_class = _safe_indexing(X, target_class_indices)
14671484

14681485
self.nn_m_.fit(X)
1469-
sl = self._assign_sl(self.nn_m_, X_class, class_sample, y)
1486+
safe_levels = self._assign_safe_levels(
1487+
self.nn_m_, X_class, class_sample, y)
14701488

14711489
# filter the points in X_class that have safe level >0
14721490
# If safe level = 0, the point is not used to
14731491
# generate synthetic instances
1474-
X_safe_indices = np.flatnonzero(sl != 0)
1492+
X_safe_indices = np.flatnonzero(safe_levels != 0)
14751493
X_safe_class = _safe_indexing(X_class, X_safe_indices)
14761494

14771495
self.nn_k_.fit(X_class)
14781496
nns = self.nn_k_.kneighbors(X_safe_class,
14791497
return_distance=False)[:, 1:]
14801498

1481-
sl_safe_class = sl[X_safe_indices]
1482-
sl_nns = sl[nns]
1499+
sl_safe_class = safe_levels[X_safe_indices]
1500+
sl_nns = safe_levels[nns]
14831501
sl_safe_t = np.array([sl_safe_class]).transpose()
14841502
with np.errstate(divide='ignore'):
1485-
sl_ratio = np.divide(sl_safe_t, sl_nns)
1503+
safe_level_ratio = np.divide(sl_safe_t, sl_nns)
14861504

1487-
X_new, y_new = self._make_samples_sl(X_safe_class, y.dtype,
1488-
class_sample, X_class,
1489-
nns, n_samples, sl_ratio,
1490-
1.0)
1505+
X_new, y_new = self._make_samples_safelevel(X_safe_class, y.dtype,
1506+
class_sample, X_class,
1507+
nns, n_samples,
1508+
safe_level_ratio,
1509+
1.0)
14911510

14921511
if sparse.issparse(X_new):
14931512
X_resampled = sparse.vstack([X_resampled, X_new])
@@ -1497,8 +1516,8 @@ def _fit_resample(self, X, y):
14971516

14981517
return X_resampled, y_resampled
14991518

1500-
def _make_samples_sl(self, X, y_dtype, y_type, nn_data, nn_num,
1501-
n_samples, sl_ratio, step_size=1.):
1519+
def _make_samples_safelevel(self, X, y_dtype, y_type, nn_data, nn_num,
1520+
n_samples, safe_level_ratio, step_size=1.):
15021521
"""A support function that returns artificial samples using
15031522
safe-level SMOTE. It is similar to _make_samples method for SMOTE.
15041523
@@ -1524,7 +1543,7 @@ def _make_samples_sl(self, X, y_dtype, y_type, nn_data, nn_num,
15241543
n_samples : int
15251544
The number of samples to generate.
15261545
1527-
sl_ratio: ndarray, shape (n_samples_safe, k_nearest_neighbours)
1546+
safe_level_ratio: ndarray, shape (n_samples_safe, k_nearest_neighbours)
15281547
15291548
step_size : float, optional (default=1.)
15301549
The step size to create samples.
@@ -1546,8 +1565,8 @@ def _make_samples_sl(self, X, y_dtype, y_type, nn_data, nn_num,
15461565
size=n_samples)
15471566
rows = np.floor_divide(samples_indices, nn_num.shape[1])
15481567
cols = np.mod(samples_indices, nn_num.shape[1])
1549-
gap_arr = step_size * self._vgenerate_gap(sl_ratio)
1550-
gaps = gap_arr.flatten()[samples_indices]
1568+
gap_array = step_size * self._vgenerate_gap(safe_level_ratio)
1569+
gaps = gap_array.flatten()[samples_indices]
15511570

15521571
y_new = np.array([y_type] * n_samples, dtype=y_dtype)
15531572

@@ -1578,12 +1597,12 @@ def _make_samples_sl(self, X, y_dtype, y_type, nn_data, nn_num,
15781597
return X_new, y_new
15791598

15801599
def _generate_gap(self, a_ratio, rand_state=None):
1581-
""" generate gap according to sl_ratio, non-vectorized version.
1600+
""" generate gap according to safe_level_ratio, non-vectorized version.
15821601
15831602
Parameters
15841603
----------
15851604
a_ratio: float
1586-
sl_ratio of a single data point
1605+
safe_level_ratio of a single data point
15871606
15881607
rand_state: random state object or int
15891608
@@ -1603,28 +1622,30 @@ def _generate_gap(self, a_ratio, rand_state=None):
16031622
elif 0 < a_ratio < 1:
16041623
gap = random_state.uniform(1-a_ratio, 1)
16051624
else:
1606-
raise ValueError('sl_ratio should be nonegative')
1625+
raise ValueError('safe_level_ratio should be nonegative')
16071626
return gap
16081627

1609-
def _vgenerate_gap(self, sl_ratio):
1628+
def _vgenerate_gap(self, safe_level_ratio):
16101629
"""
1611-
generate gap according to sl_ratio, vectorized version of _generate_gap
1630+
generate gap according to safe_level_ratio, vectorized version
1631+
of _generate_gap
16121632
16131633
Parameters
16141634
-----------
1615-
sl_ratio: ndarray shape (n_samples_safe, k_nearest_neighbours)
1616-
sl_ratio of all instances with safe_level>0 in the specified
1617-
class
1635+
safe_level_ratio: ndarray shape (n_samples_safe, k_nearest_neighbours)
1636+
safe_level_ratio of all instances with safe_level>0 in the
1637+
specified class
16181638
16191639
Returns
16201640
------------
1621-
gap_arr: ndarray shape (n_samples_safe, k_nearest_neighbours)
1641+
gap_array: ndarray shape (n_samples_safe, k_nearest_neighbours)
16221642
the gap for all instances with safe_level>0 in the specified
16231643
class
16241644
16251645
"""
16261646
prng = check_random_state(self.random_state)
1627-
rand_state = prng.randint(sl_ratio.size+1, size=sl_ratio.shape)
1647+
rand_state = prng.randint(
1648+
safe_level_ratio.size+1, size=safe_level_ratio.shape)
16281649
vgap = np.vectorize(self._generate_gap)
1629-
gap_arr = vgap(sl_ratio, rand_state)
1630-
return gap_arr
1650+
gap_array = vgap(safe_level_ratio, rand_state)
1651+
return gap_array
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import pytest
22
import numpy as np
3+
from collections import Counter
34

45
from sklearn.neighbors import NearestNeighbors
56
from scipy import sparse
67

78
from sklearn.utils._testing import assert_allclose
89
from sklearn.utils._testing import assert_array_equal
910

10-
from imblearn.over_sampling import SLSMOTE
11+
from imblearn.over_sampling import SafeLevelSMOTE
1112

1213

1314
def data_np():
@@ -27,40 +28,53 @@ def data_sparse(format):
2728
"data",
2829
[data_np(), data_sparse('csr'), data_sparse('csc')]
2930
)
30-
def test_slsmote(data):
31+
def test_safelevel_smote(data):
3132
y_gt = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0,
3233
0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0])
3334
X, y = data
34-
slsmote = SLSMOTE(random_state=42)
35-
X_res, y_res = slsmote.fit_resample(X, y)
35+
safelevel_smote = SafeLevelSMOTE(random_state=42)
36+
X_res, y_res = safelevel_smote.fit_resample(X, y)
3637

3738
assert X_res.shape == (24, 2)
3839
assert_array_equal(y_res, y_gt)
3940

4041

41-
def test_slsmote_nn():
42+
def test_sl_smote_nn():
4243
X, y = data_np()
43-
slsmote = SLSMOTE(random_state=42)
44-
slsmote_nn = SLSMOTE(
44+
safelevel_smote = SafeLevelSMOTE(random_state=42)
45+
safelevel_smote_nn = SafeLevelSMOTE(
4546
random_state=42,
4647
k_neighbors=NearestNeighbors(n_neighbors=6),
4748
m_neighbors=NearestNeighbors(n_neighbors=11),
4849
)
4950

50-
X_res_1, y_res_1 = slsmote.fit_resample(X, y)
51-
X_res_2, y_res_2 = slsmote_nn.fit_resample(X, y)
51+
X_res_1, y_res_1 = safelevel_smote.fit_resample(X, y)
52+
X_res_2, y_res_2 = safelevel_smote_nn.fit_resample(X, y)
5253

5354
assert_allclose(X_res_1, X_res_2)
5455
assert_array_equal(y_res_1, y_res_2)
5556

5657

57-
def test_slsmote_pd():
58+
def test_sl_smote_pd():
5859
pd = pytest.importorskip("pandas")
5960
X, y = data_np()
6061
X_pd = pd.DataFrame(X)
61-
slsmote = SLSMOTE(random_state=42)
62-
X_res, y_res = slsmote.fit_resample(X, y)
63-
X_res_pd, y_res_pd = slsmote.fit_resample(X_pd, y)
62+
safelevel_smote = SafeLevelSMOTE(random_state=42)
63+
X_res, y_res = safelevel_smote.fit_resample(X, y)
64+
X_res_pd, y_res_pd = safelevel_smote.fit_resample(X_pd, y)
6465

6566
assert X_res_pd.tolist() == X_res.tolist()
6667
assert_allclose(y_res_pd, y_res)
68+
69+
70+
def test_sl_smote_multiclass():
71+
rng = np.random.RandomState(42)
72+
X = rng.randn(50, 2)
73+
y = np.array([0] * 10 + [1] * 15 + [2] * 25)
74+
safelevel_smote = SafeLevelSMOTE(random_state=42)
75+
X_res, y_res = safelevel_smote.fit_resample(X, y)
76+
77+
count_y_res = Counter(y_res)
78+
assert count_y_res[0] == 25
79+
assert count_y_res[1] == 25
80+
assert count_y_res[2] == 25

0 commit comments

Comments
 (0)