[MRG] ENH: K-Means SMOTE implementation #435


Merged · 42 commits · Jun 12, 2019

Changes from all commits (42 commits)
dba11ac
Initial K-Means SMOTE commit.
StephanHeijl Jun 21, 2018
d54ffc2
PEP8, PyFlakes fixes, corrected paper reference.
StephanHeijl Jun 21, 2018
4a9b990
Added examples.
StephanHeijl Jun 21, 2018
5dd0526
Added error when clustering fails to find a cluster with sufficient s…
StephanHeijl Jun 21, 2018
642e62e
Added test for wrong hyperparameters
StephanHeijl Jun 22, 2018
fd663f1
Save an indexing operation if cluster_class_mean is insufficient.
StephanHeijl Jun 22, 2018
0ef982b
Simplified vstack function call.
StephanHeijl Jun 22, 2018
4c37593
Resolved stacking error
StephanHeijl Jun 22, 2018
efb6a75
Added extra arguments for kmeans sampling, addressed suggestions by F…
StephanHeijl Jul 1, 2018
131e3b3
Resolved errors and warnings
StephanHeijl Jul 1, 2018
7de5951
Resolve PEP8 style issues
StephanHeijl Jul 1, 2018
7029266
Added special k-means cases and tests.
StephanHeijl Jul 12, 2018
7696e44
solve conflicts
glemaitre Aug 22, 2018
f99433b
Merge remote-tracking branch 'origin/master' into StephanHeijl-kmeans…
glemaitre Aug 27, 2018
d65cea3
Merge remote-tracking branch 'origin/master' into pr/StephanHeijl/435
glemaitre Sep 12, 2018
c5ab59c
Removed KMeans specific code
StephanHeijl Mar 3, 2019
25e8ef7
Merge branch 'master' into kmeans-smote
StephanHeijl Mar 3, 2019
950df34
Restored KMeansSMOTE
StephanHeijl Mar 3, 2019
2851b7e
Resolved KMeansSmote errors
StephanHeijl Mar 4, 2019
25cd90b
Resolved python2.7 errors
StephanHeijl Mar 5, 2019
750decc
improved code coverage
StephanHeijl Mar 5, 2019
b2b766d
Resolved test error resulting from coverage improvement
StephanHeijl Mar 5, 2019
0358d0f
Made custom kmeans test deterministic
StephanHeijl Mar 5, 2019
01f31d0
Removed superfluous check
StephanHeijl Mar 5, 2019
7aa6f86
Change test to use custom KMeans instance (MiniBatchKmeans was default)
StephanHeijl Mar 7, 2019
1f34912
Resolved PEP8 issues
StephanHeijl Mar 10, 2019
05d3f40
Merge branch 'kmeans-smote' of github.com:StephanHeijl/imbalanced-lea…
StephanHeijl Mar 10, 2019
6129fbf
Fixed using the wrong variable name
StephanHeijl Mar 10, 2019
9537ec9
Fixed error in _make_samples call, resolves mediocre sample selection.
StephanHeijl Apr 1, 2019
b6fbca4
Updated KMeansSMOTE tests
StephanHeijl Apr 1, 2019
ca9b541
Clarified RuntimeError with solution to problem
StephanHeijl Apr 1, 2019
1b4dfd2
Adjusted documentation according to @chkoar's review.
StephanHeijl May 6, 2019
367f3a0
Slightly adjusted test to 'fail' for regular SMOTE.
StephanHeijl May 6, 2019
7d79475
Merge branch 'master' into kmeans-smote
StephanHeijl May 6, 2019
4a414c3
Fix expected print output
StephanHeijl May 6, 2019
d9fa137
Added ratio back to pass check_samplers_ratio_fit_resample test
StephanHeijl May 6, 2019
cf1b1fe
Added KMeansSMOTE to DONT_SUPPORT_RATIO and removed space from print
StephanHeijl May 6, 2019
9842573
Merge remote-tracking branch 'origin/master' into kmeans-smote
glemaitre Jun 12, 2019
0c4dd16
cleaning
glemaitre Jun 12, 2019
f4ec980
DOC: add an entry in documentation
glemaitre Jun 12, 2019
c3a1502
DOC: add entry in API documentation
glemaitre Jun 12, 2019
032842e
DOC: add whats new entry
glemaitre Jun 12, 2019
1 change: 1 addition & 0 deletions doc/api.rst
@@ -71,6 +71,7 @@ Prototype selection

over_sampling.ADASYN
over_sampling.BorderlineSMOTE
over_sampling.KMeansSMOTE
over_sampling.RandomOverSampler
over_sampling.SMOTE
over_sampling.SMOTENC
12 changes: 10 additions & 2 deletions doc/over_sampling.rst
@@ -152,8 +152,8 @@ nearest neighbors class. Those variants are presented in the figure below.
:align: center


The :class:`BorderlineSMOTE` [HWB2005]_ and :class:`SVMSMOTE` [NCK2009]_ offer
some variant of the SMOTE algorithm::
The :class:`BorderlineSMOTE` [HWB2005]_, :class:`SVMSMOTE` [NCK2009]_, and
:class:`KMeansSMOTE` [LDB2017]_ offer variants of the SMOTE algorithm::

>>> from imblearn.over_sampling import BorderlineSMOTE
>>> X_resampled, y_resampled = BorderlineSMOTE().fit_resample(X, y)
@@ -209,6 +209,10 @@ other extra interpolation.
Knowledge Engineering and Soft Data Paradigms, 3(1), pp.4-21,
2009.

.. [LDB2017] Felix Last, Georgios Douzas, Fernando Bacao, "Oversampling for
Imbalanced Learning Based on K-Means and SMOTE"
https://arxiv.org/abs/1711.00837

Mathematical formulation
========================

@@ -266,6 +270,10 @@ parameter of the SVM classifier allows to select more or less support vectors.
For both borderline and SVM SMOTE, a neighborhood is defined using the
parameter ``m_neighbors`` to decide if a sample is in danger, safe, or noise.

**KMeans** SMOTE --- cf. to :class:`KMeansSMOTE` --- applies a KMeans
clustering before over-sampling with SMOTE. The clustering groups samples
together and new samples are generated depending on the cluster density.
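
A minimal usage sketch, assuming the same ``X`` and ``y`` as in the previous
examples (the resulting class counts depend on the clustering)::

>>> from imblearn.over_sampling import KMeansSMOTE
>>> X_resampled, y_resampled = KMeansSMOTE(random_state=0).fit_resample(X, y)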

ADASYN works similarly to the regular SMOTE. However, the number of
samples generated for each :math:`x_i` is proportional to the number of samples
which are not from the same class than :math:`x_i` in a given
4 changes: 4 additions & 0 deletions doc/whats_new/v0.5.rst
@@ -37,6 +37,10 @@ Enhancement
and issue template showing how to print system and dependency information
from the command line. :pr:`557` by :user:`Alexander L. Hayes <batflyer>`.

- Add :class:`imblearn.over_sampling.KMeansSMOTE` which is an over-sampler
that clusters the data points before applying SMOTE.
:pr:`435` by :user:`Stephan Heijl <StephanHeijl>`.

Maintenance
...........

14 changes: 10 additions & 4 deletions examples/over-sampling/plot_comparison_over_sampling.py
@@ -21,7 +21,8 @@

from imblearn.pipeline import make_pipeline
from imblearn.over_sampling import ADASYN
from imblearn.over_sampling import SMOTE, BorderlineSMOTE, SVMSMOTE, SMOTENC
from imblearn.over_sampling import (SMOTE, BorderlineSMOTE, SVMSMOTE, SMOTENC,
KMeansSMOTE)
from imblearn.over_sampling import RandomOverSampler
from imblearn.base import BaseSampler

@@ -204,18 +205,23 @@ def _fit_resample(self, X, y):
# SMOTE proposes several variants by identifying specific samples to consider
# during the resampling. The borderline version will detect which point to
# select which are in the border between two classes. The SVM version will use
# the support vectors found using an SVM algorithm to create new samples.
# the support vectors found using an SVM algorithm to create new samples,
# while the KMeans version applies a clustering before generating samples in
# each cluster independently, depending on each cluster's density.

fig, ((ax1, ax2), (ax3, ax4),
(ax5, ax6), (ax7, ax8)) = plt.subplots(4, 2, figsize=(15, 30))
(ax5, ax6), (ax7, ax8),
(ax9, ax10)) = plt.subplots(5, 2, figsize=(15, 30))
X, y = create_dataset(n_samples=5000, weights=(0.01, 0.05, 0.94),
class_sep=0.8)

ax_arr = ((ax1, ax2), (ax3, ax4), (ax5, ax6), (ax7, ax8))

ax_arr = ((ax1, ax2), (ax3, ax4), (ax5, ax6), (ax7, ax8), (ax9, ax10))
for ax, sampler in zip(ax_arr,
(SMOTE(random_state=0),
BorderlineSMOTE(random_state=0, kind='borderline-1'),
BorderlineSMOTE(random_state=0, kind='borderline-2'),
KMeansSMOTE(random_state=0),
SVMSMOTE(random_state=0))):
clf = make_pipeline(sampler, LinearSVC())
clf.fit(X, y)
3 changes: 2 additions & 1 deletion imblearn/over_sampling/__init__.py
@@ -7,8 +7,9 @@
from ._random_over_sampler import RandomOverSampler
from ._smote import SMOTE
from ._smote import BorderlineSMOTE
from ._smote import KMeansSMOTE
from ._smote import SVMSMOTE
from ._smote import SMOTENC

__all__ = ['ADASYN', 'RandomOverSampler',
__all__ = ['ADASYN', 'RandomOverSampler', 'KMeansSMOTE',
'SMOTE', 'BorderlineSMOTE', 'SVMSMOTE', 'SMOTENC']
236 changes: 236 additions & 0 deletions imblearn/over_sampling/_smote.py
@@ -8,6 +8,7 @@

from __future__ import division

import math
import types
import warnings
from collections import Counter
@@ -16,6 +17,8 @@
from scipy import sparse

from sklearn.base import clone
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import pairwise_distances
from sklearn.preprocessing import OneHotEncoder
from sklearn.svm import SVC
from sklearn.utils import check_random_state
@@ -1090,3 +1093,236 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step):
sample[start_idx + col_sel] = 1

return sparse.csr_matrix(sample) if sparse.issparse(X) else sample


@Substitution(
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
random_state=_random_state_docstring)
class KMeansSMOTE(BaseSMOTE):
"""Apply a KMeans clustering before to over-sample using SMOTE.

This is an implementation of the algorithm described in [1]_.

Read more in the :ref:`User Guide <smote_adasyn>`.

Parameters
----------
{sampling_strategy}

{random_state}

k_neighbors : int or object, optional (default=2)
If ``int``, number of nearest neighbours used to construct synthetic
samples. If object, an estimator that inherits from
:class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to
find the k_neighbors.

n_jobs : int, optional (default=1)
The number of threads to open if possible.

kmeans_estimator : int or object, optional (default=MiniBatchKMeans())
A KMeans instance or the number of clusters to be used. By default,
a :class:`sklearn.cluster.MiniBatchKMeans` is used, which tends to
scale better with a large number of samples.

cluster_balance_threshold : str or float, optional (default="auto")
The threshold at which a cluster is considered balanced and where
samples of the class selected for SMOTE will be oversampled. If
"auto", the threshold is determined from the sampling ratio for each
class; otherwise it can be set manually.

density_exponent : str or float, optional (default="auto")
This exponent is used to determine the density of a cluster. Leaving
this to "auto" will use an exponent based on the cluster size.

Attributes
----------
kmeans_estimator_ : estimator
The fitted clustering estimator used before applying SMOTE.

nn_k_ : estimator
The fitted k-NN estimator used in SMOTE.

cluster_balance_threshold_ : float
The threshold used during ``fit`` to decide whether a cluster is balanced.

References
----------
.. [1] Felix Last, Georgios Douzas, Fernando Bacao, "Oversampling for
Imbalanced Learning Based on K-Means and SMOTE"
https://arxiv.org/abs/1711.00837

Examples
--------

>>> import numpy as np
>>> from imblearn.over_sampling import KMeansSMOTE
>>> from sklearn.datasets import make_blobs
>>> blobs = [100, 800, 100]
>>> X, y = make_blobs(blobs, centers=[(-10, 0), (0,0), (10, 0)])
>>> # Add a single 0 sample in the middle blob
>>> X = np.concatenate([X, [[0, 0]]])
>>> y = np.append(y, 0)
>>> # Make this a binary classification problem
>>> y = y == 1
>>> sm = KMeansSMOTE(random_state=42)
>>> X_res, y_res = sm.fit_resample(X, y)
>>> # Find the number of new samples in the middle blob
>>> n_res_in_middle = ((X_res[:, 0] > -5) & (X_res[:, 0] < 5)).sum()
>>> print("Samples in the middle blob: %s" % n_res_in_middle)
Samples in the middle blob: 801
>>> print("Middle blob unchanged: %s" % (n_res_in_middle == blobs[1] + 1))
Middle blob unchanged: True
>>> print("More 0 samples: %s" % ((y_res == 0).sum() > (y == 0).sum()))
More 0 samples: True

"""
Review comment (Member):

At the end of the docstring we could add the reference of the paper and after that an interactive example. Check here for instance.

Reply (StephanHeijl, Contributor, Author):

I added an interactive example that specifically demonstrates the utility of the KMeansSMOTE class: 3 blobs, with the positive class in the middle, the negative classes on the outside, and a single negative sample in the middle blob. The example shows that after resampling no new samples are added in the middle blob. Inspired by the following toy problem:

[image: toy problem illustration]

def __init__(self,
sampling_strategy='auto',
random_state=None,
k_neighbors=2,
n_jobs=1,
kmeans_estimator=None,
cluster_balance_threshold="auto",
density_exponent="auto"):
super().__init__(
sampling_strategy=sampling_strategy, random_state=random_state,
k_neighbors=k_neighbors, n_jobs=n_jobs)
self.kmeans_estimator = kmeans_estimator
self.cluster_balance_threshold = cluster_balance_threshold
self.density_exponent = density_exponent

def _validate_estimator(self):
super()._validate_estimator()
if self.kmeans_estimator is None:
self.kmeans_estimator_ = MiniBatchKMeans(
random_state=self.random_state)
elif isinstance(self.kmeans_estimator, int):
self.kmeans_estimator_ = MiniBatchKMeans(
n_clusters=self.kmeans_estimator,
random_state=self.random_state)
else:
self.kmeans_estimator_ = clone(self.kmeans_estimator)

# validate the parameters
for param_name in ('cluster_balance_threshold', 'density_exponent'):
param = getattr(self, param_name)
if isinstance(param, str) and param != 'auto':
raise ValueError(
"'{}' should be 'auto' when a string is passed. "
"Got {} instead.".format(param_name, repr(param))
)

self.cluster_balance_threshold_ = (
self.cluster_balance_threshold
if self.kmeans_estimator_.n_clusters != 1 else -np.inf
)
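
# A usage sketch of how ``kmeans_estimator`` is resolved above (assumed
# typical usage, not a tested example): an int is wrapped into a
# MiniBatchKMeans with that many clusters, while an estimator instance is
# cloned as-is.
#   >>> from sklearn.cluster import KMeans
#   >>> from imblearn.over_sampling import KMeansSMOTE
#   >>> KMeansSMOTE(kmeans_estimator=KMeans(n_clusters=10))  # cloned
#   >>> KMeansSMOTE(kmeans_estimator=10)  # MiniBatchKMeans(n_clusters=10)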


def _find_cluster_sparsity(self, X):
"""Compute the cluster sparsity."""
euclidean_distances = pairwise_distances(X, metric="euclidean",
n_jobs=self.n_jobs)
# zero out the diagonal elements (distance of a sample to itself)
for ind in range(X.shape[0]):
euclidean_distances[ind, ind] = 0

non_diag_elements = (X.shape[0] ** 2) - X.shape[0]
mean_distance = euclidean_distances.sum() / non_diag_elements
exponent = (math.log(X.shape[0], 1.6) ** 1.8 * 0.16
if self.density_exponent == 'auto'
else self.density_exponent)
return (mean_distance ** exponent) / X.shape[0]
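
# Reading of the heuristic above (a sketch only, no additional behavior):
# for a cluster of n samples with mean pairwise distance d, the returned
# sparsity is d ** e / n, where e = log(n, 1.6) ** 1.8 * 0.16 when
# density_exponent='auto'. Sparser clusters receive larger normalized
# weights in _sample below, and therefore more synthetic samples.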

# FIXME: rename _sample -> _fit_resample in 0.6
def _fit_resample(self, X, y):
return self._sample(X, y)

def _sample(self, X, y):
self._validate_estimator()
X_resampled = X.copy()
y_resampled = y.copy()
total_inp_samples = sum(self.sampling_strategy_.values())

for class_sample, n_samples in self.sampling_strategy_.items():
if n_samples == 0:
continue

X_clusters = self.kmeans_estimator_.fit_predict(X)
valid_clusters = []
cluster_sparsities = []

# identify the clusters which satisfy the requirements to be over-sampled
for cluster_idx in range(self.kmeans_estimator_.n_clusters):

cluster_mask = np.flatnonzero(X_clusters == cluster_idx)
X_cluster = safe_indexing(X, cluster_mask)
y_cluster = safe_indexing(y, cluster_mask)

cluster_class_mean = (y_cluster == class_sample).mean()

if self.cluster_balance_threshold_ == "auto":
balance_threshold = n_samples / total_inp_samples / 2
else:
balance_threshold = self.cluster_balance_threshold_

# skip the cluster if the class to over-sample is not sufficiently
# represented
if cluster_class_mean < balance_threshold:
continue

# not enough samples to apply SMOTE
anticipated_samples = cluster_class_mean * X_cluster.shape[0]
if anticipated_samples < self.nn_k_.n_neighbors:
continue

X_cluster_class = safe_indexing(
X_cluster, np.flatnonzero(y_cluster == class_sample)
)

valid_clusters.append(cluster_mask)
cluster_sparsities.append(
self._find_cluster_sparsity(X_cluster_class)
)

if not valid_clusters:
raise RuntimeError(
"No clusters found with sufficient samples of "
"class {}. Try lowering the cluster_balance_threshold "
"or increasing the number of "
"clusters.".format(class_sample))

cluster_sparsities = np.array(cluster_sparsities)
cluster_weights = cluster_sparsities / cluster_sparsities.sum()

for valid_cluster_idx, valid_cluster in enumerate(valid_clusters):
X_cluster = safe_indexing(X, valid_cluster)
y_cluster = safe_indexing(y, valid_cluster)

X_cluster_class = safe_indexing(
X_cluster, np.flatnonzero(y_cluster == class_sample)
)

self.nn_k_.fit(X_cluster_class)
nns = self.nn_k_.kneighbors(X_cluster_class,
return_distance=False)[:, 1:]

cluster_n_samples = int(math.ceil(
n_samples * cluster_weights[valid_cluster_idx])
)

X_new, y_new = self._make_samples(X_cluster_class,
y.dtype,
class_sample,
X_cluster_class,
nns,
cluster_n_samples,
1.0)

stack = sparse.vstack if sparse.issparse(X_new) else np.vstack
X_resampled = stack((X_resampled, X_new))
y_resampled = np.hstack((y_resampled, y_new))

return X_resampled, y_resampled
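
A quick end-to-end sketch of the new sampler (the dataset below is an
illustrative assumption, not an output verified in this PR; with too few
qualifying clusters this can raise the RuntimeError above, in which case
lowering ``cluster_balance_threshold`` or increasing the number of clusters
helps):

>>> from collections import Counter
>>> from sklearn.datasets import make_classification
>>> from imblearn.over_sampling import KMeansSMOTE
>>> X, y = make_classification(n_samples=1000, weights=[0.9, 0.1],
...                            random_state=0)
>>> X_res, y_res = KMeansSMOTE(random_state=0).fit_resample(X, y)
>>> sorted(Counter(y_res).items())  # doctest: +SKIP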