Skip to content

Commit bcc3069

Browse files
committed
added safe-level-smote method
1 parent 321b751 commit bcc3069

File tree

2 files changed

+326
-4
lines changed

2 files changed

+326
-4
lines changed

imblearn/over_sampling/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ._smote import KMeansSMOTE
1111
from ._smote import SVMSMOTE
1212
from ._smote import SMOTENC
13+
from ._smote import SLSMOTE
1314

1415
__all__ = [
1516
"ADASYN",
@@ -19,4 +20,5 @@
1920
"BorderlineSMOTE",
2021
"SVMSMOTE",
2122
"SMOTENC",
23+
"SLSMOTE",
2224
]

imblearn/over_sampling/_smote.py

Lines changed: 324 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -586,12 +586,14 @@ def _fit_resample(self, X, y):
586586
n_generated_samples = int(fractions * (n_samples + 1))
587587
if np.count_nonzero(danger_bool) > 0:
588588
nns = self.nn_k_.kneighbors(
589-
_safe_indexing(support_vector, np.flatnonzero(danger_bool)),
589+
_safe_indexing(
590+
support_vector, np.flatnonzero(danger_bool)),
590591
return_distance=False,
591592
)[:, 1:]
592593

593594
X_new_1, y_new_1 = self._make_samples(
594-
_safe_indexing(support_vector, np.flatnonzero(danger_bool)),
595+
_safe_indexing(
596+
support_vector, np.flatnonzero(danger_bool)),
595597
y.dtype,
596598
class_sample,
597599
X_class,
@@ -602,12 +604,14 @@ def _fit_resample(self, X, y):
602604

603605
if np.count_nonzero(safety_bool) > 0:
604606
nns = self.nn_k_.kneighbors(
605-
_safe_indexing(support_vector, np.flatnonzero(safety_bool)),
607+
_safe_indexing(
608+
support_vector, np.flatnonzero(safety_bool)),
606609
return_distance=False,
607610
)[:, 1:]
608611

609612
X_new_2, y_new_2 = self._make_samples(
610-
_safe_indexing(support_vector, np.flatnonzero(safety_bool)),
613+
_safe_indexing(
614+
support_vector, np.flatnonzero(safety_bool)),
611615
y.dtype,
612616
class_sample,
613617
X_class,
@@ -1308,3 +1312,319 @@ def _fit_resample(self, X, y):
13081312
y_resampled = np.hstack((y_resampled, y_new))
13091313

13101314
return X_resampled, y_resampled
1315+
1316+
1317+
@Substitution(
    sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
    random_state=_random_state_docstring,
)
class SLSMOTE(BaseSMOTE):
    """Class to perform over-sampling using safe-level SMOTE.

    This is an implementation of the Safe-level-SMOTE described in [2]_.

    Parameters
    ----------
    {sampling_strategy}

    {random_state}

    k_neighbors : int or object, optional (default=5)
        If ``int``, number of nearest neighbours to used to construct
        synthetic samples. If object, an estimator that inherits from
        :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to
        find the k_neighbors.

    m_neighbors : int or object, optional (default=10)
        If ``int``, number of nearest neighbours to use to determine the safe
        level of an instance. If object, an estimator that inherits from
        :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used
        to find the m_neighbors.

    n_jobs : int or None, optional (default=None)
        Number of CPU cores used during the cross-validation loop.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See
        `Glossary <https://scikit-learn.org/stable/glossary.html#term-n-jobs>`_
        for more details.

    Notes
    -----
    See the original papers: [2]_ for more details.

    Supports multi-class resampling. A one-vs.-rest scheme is used as
    originally proposed in [1]_.

    See also
    --------
    SMOTE : Over-sample using SMOTE.

    SMOTENC : Over-sample using SMOTE for continuous and categorical features.

    SVMSMOTE : Over-sample using SVM-SMOTE variant.

    BorderlineSMOTE : Over-sample using Borderline-SMOTE.

    ADASYN : Over-sample using ADASYN.

    KMeansSMOTE: Over-sample using KMeans-SMOTE variant.

    References
    ----------
    .. [1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, "SMOTE:
       synthetic minority over-sampling technique," Journal of artificial
       intelligence research, 321-357, 2002.

    .. [2] C. Bunkhumpornpat, K. Sinapiromsaran, C. Lursinsap, "Safe-level-
       SMOTE: Safe-level-synthetic minority over-sampling technique for
       handling the class imbalanced problem," In: Theeramunkong T.,
       Kijsirikul B., Cercone N., Ho TB. (eds) Advances in Knowledge Discovery
       and Data Mining. PAKDD 2009. Lecture Notes in Computer Science,
       vol 5476. Springer, Berlin, Heidelberg, 475-482, 2009.

    Examples
    --------

    >>> from collections import Counter
    >>> from sklearn.datasets import make_classification
    >>> from imblearn.over_sampling import \
SLSMOTE # doctest: +NORMALIZE_WHITESPACE
    >>> X, y = make_classification(n_classes=2, class_sep=2,
    ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
    ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
    >>> print('Original dataset shape %s' % Counter(y))
    Original dataset shape Counter({{1: 900, 0: 100}})
    >>> sm = SLSMOTE(random_state=42)
    >>> X_res, y_res = sm.fit_resample(X, y)
    >>> print('Resampled dataset shape %s' % Counter(y_res))
    Resampled dataset shape Counter({{0: 900, 1: 900}})

    """

    def __init__(self,
                 sampling_strategy='auto',
                 random_state=None,
                 k_neighbors=5,
                 m_neighbors=10,
                 n_jobs=None):
        super().__init__(sampling_strategy=sampling_strategy,
                         random_state=random_state, k_neighbors=k_neighbors,
                         n_jobs=n_jobs)
        self.m_neighbors = m_neighbors

    def _assign_sl(self, nn_estimator, samples, target_class, y):
        """Assign the safe levels to the instances in the target class.

        The safe level of an instance is the number of its m nearest
        neighbours that belong to the target class.

        Parameters
        ----------
        nn_estimator : estimator
            An estimator that inherits from
            :class:`sklearn.neighbors.base.KNeighborsMixin`. It gets the
            nearest neighbors that are used to determine the safe levels.

        samples : {array-like, sparse matrix}, shape (n_samples, n_features)
            The samples to which the safe levels are assigned.

        target_class : int or str
            The target corresponding class being over-sampled.

        y : array-like, shape (n_samples,)
            The true label in order to calculate the safe levels.

        Returns
        -------
        output : ndarray, shape (n_samples,)
            A ndarray where the values refer to the safe level of the
            instances in the target class.
        """
        # Drop column 0: the estimator is fitted on the full data set, so the
        # first neighbour of each sample is the sample itself.
        x = nn_estimator.kneighbors(samples, return_distance=False)[:, 1:]
        nn_label = (y[x] == target_class).astype(int)
        sl = np.sum(nn_label, axis=1)
        return sl

    def _validate_estimator(self):
        """Create and validate the ``nn_k_`` and ``nn_m_`` estimators."""
        super()._validate_estimator()
        # One extra neighbour compensates for the sample itself being
        # returned as its own nearest neighbour after fitting on X.
        self.nn_m_ = check_neighbors_object('m_neighbors', self.m_neighbors,
                                            additional_neighbor=1)
        self.nn_m_.set_params(**{"n_jobs": self.n_jobs})

    def _fit_resample(self, X, y):
        self._validate_estimator()

        X_resampled = X.copy()
        y_resampled = y.copy()

        for class_sample, n_samples in self.sampling_strategy_.items():
            if n_samples == 0:
                continue
            target_class_indices = np.flatnonzero(y == class_sample)
            X_class = _safe_indexing(X, target_class_indices)

            self.nn_m_.fit(X)
            sl = self._assign_sl(self.nn_m_, X_class, class_sample, y)

            # Filter the points in X_class that have safe level > 0.
            # A point with safe level 0 (all m neighbours belong to other
            # classes) is treated as noise and is not used as a seed to
            # generate synthetic instances.
            X_safe_indices = np.flatnonzero(sl != 0)
            X_safe_class = _safe_indexing(X_class, X_safe_indices)

            self.nn_k_.fit(X_class)
            nns = self.nn_k_.kneighbors(X_safe_class,
                                        return_distance=False)[:, 1:]

            sl_safe_class = sl[X_safe_indices]
            sl_nns = sl[nns]
            # Column vector of seed safe levels so the division broadcasts
            # against the (n_safe, k) neighbour safe levels.
            sl_safe_t = sl_safe_class[:, np.newaxis]
            # A neighbour with safe level 0 yields an infinite ratio, which
            # _generate_gap maps to gap 0 (the seed point is duplicated).
            with np.errstate(divide='ignore'):
                sl_ratio = np.divide(sl_safe_t, sl_nns)

            X_new, y_new = self._make_samples_sl(X_safe_class, y.dtype,
                                                 class_sample, X_class,
                                                 nns, n_samples, sl_ratio,
                                                 1.0)

            if sparse.issparse(X_new):
                X_resampled = sparse.vstack([X_resampled, X_new])
            else:
                X_resampled = np.vstack((X_resampled, X_new))
            y_resampled = np.hstack((y_resampled, y_new))

        return X_resampled, y_resampled

    def _make_samples_sl(self, X, y_dtype, y_type, nn_data, nn_num,
                         n_samples, sl_ratio, step_size=1.):
        """A support function that returns artificial samples using
        safe-level SMOTE. It is similar to _make_samples method for SMOTE.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples_safe, n_features)
            Points from which the points will be created.

        y_dtype : dtype
            The data type of the targets.

        y_type : str or int
            The minority target value, just so the function can return the
            target values for the synthetic variables with correct length in
            a clear format.

        nn_data : ndarray, shape (n_samples_all, n_features)
            Data set carrying all the neighbours to be used

        nn_num : ndarray, shape (n_samples_safe, k_nearest_neighbours)
            The nearest neighbours of each sample in `nn_data`.

        n_samples : int
            The number of samples to generate.

        sl_ratio : ndarray, shape (n_samples_safe, k_nearest_neighbours)
            The ratio of the seed's safe level to each neighbour's safe
            level; it controls the interpolation range via `_generate_gap`.

        step_size : float, optional (default=1.)
            The step size to create samples.

        Returns
        -------
        X_new : {ndarray, sparse matrix}, shape (n_samples_new, n_features)
            Synthetically generated samples using the safe-level method.

        y_new : ndarray, shape (n_samples_new,)
            Target values for synthetic samples.
        """
        random_state = check_random_state(self.random_state)
        # Each flat index selects one (seed, neighbour) pair.
        samples_indices = random_state.randint(low=0,
                                               high=len(nn_num.flatten()),
                                               size=n_samples)
        rows = np.floor_divide(samples_indices, nn_num.shape[1])
        cols = np.mod(samples_indices, nn_num.shape[1])
        # Gaps are precomputed for every (seed, neighbour) pair and then
        # gathered for the sampled pairs.
        gap_arr = step_size * self._vgenerate_gap(sl_ratio)
        gaps = gap_arr.flatten()[samples_indices]

        y_new = np.array([y_type] * n_samples, dtype=y_dtype)

        if sparse.issparse(X):
            row_indices, col_indices, samples = [], [], []
            for i, (row, col, gap) in enumerate(zip(rows, cols, gaps)):
                # Skip all-zero seed rows: interpolation from them would add
                # no non-zero entries to the CSR triplets.
                if X[row].nnz:
                    sample = self._generate_sample(
                        X, nn_data, nn_num, row, col, gap)
                    row_indices += [i] * len(sample.indices)
                    col_indices += sample.indices.tolist()
                    samples += sample.data.tolist()
            return (
                sparse.csr_matrix(
                    (samples, (row_indices, col_indices)),
                    [len(samples_indices), X.shape[1]],
                    dtype=X.dtype,
                ),
                y_new,
            )
        else:
            X_new = np.zeros((n_samples, X.shape[1]), dtype=X.dtype)
            for i, (row, col, gap) in enumerate(zip(rows, cols, gaps)):
                X_new[i] = self._generate_sample(X, nn_data, nn_num,
                                                 row, col, gap)
            return X_new, y_new

    def _generate_gap(self, a_ratio, rand_state=None):
        """Generate gap according to sl_ratio, non-vectorized version.

        Implements the case analysis of the safe-level SMOTE paper: an
        infinite ratio duplicates the seed (gap 0); a ratio >= 1 keeps the
        synthetic point close to the (safer) seed; a ratio in (0, 1) pushes
        it towards the (safer) neighbour.

        Parameters
        ----------
        a_ratio : float
            sl_ratio of a single data point.

        rand_state : int, RandomState instance or None
            Seed or random state used to draw the gap.

        Returns
        -------
        gap : float
            A number between 0 and 1.

        Raises
        ------
        ValueError
            If ``a_ratio`` is negative, zero or NaN.
        """
        random_state = check_random_state(rand_state)
        if np.isinf(a_ratio):
            gap = 0
        elif a_ratio >= 1:
            gap = random_state.uniform(0, 1/a_ratio)
        elif 0 < a_ratio < 1:
            gap = random_state.uniform(1-a_ratio, 1)
        else:
            raise ValueError('sl_ratio should be nonnegative')
        return gap

    def _vgenerate_gap(self, sl_ratio):
        """Generate gap according to sl_ratio, vectorized version of
        _generate_gap.

        Parameters
        ----------
        sl_ratio : ndarray, shape (n_samples_safe, k_nearest_neighbours)
            sl_ratio of all instances with safe_level > 0 in the specified
            class.

        Returns
        -------
        gap_arr : ndarray, shape (n_samples_safe, k_nearest_neighbours)
            The gap for all instances with safe_level > 0 in the specified
            class.
        """
        prng = check_random_state(self.random_state)
        # Per-element integer seeds derived from self.random_state keep the
        # whole array reproducible.
        # NOTE(review): seeds are drawn from [0, sl_ratio.size], a small
        # range, so distinct cells can share a seed and thus a gap value —
        # confirm this is acceptable for the sampling quality.
        rand_state = prng.randint(sl_ratio.size+1, size=sl_ratio.shape)
        vgap = np.vectorize(self._generate_gap)
        gap_arr = vgap(sl_ratio, rand_state)
        return gap_arr

0 commit comments

Comments
 (0)