diff --git a/doc/api.rst b/doc/api.rst
index 750c402f8..e98dfe47b 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -71,6 +71,7 @@ Prototype selection
    over_sampling.ADASYN
    over_sampling.BorderlineSMOTE
+   over_sampling.KMeansSMOTE
    over_sampling.RandomOverSampler
    over_sampling.SMOTE
    over_sampling.SMOTENC
diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst
index 2d78825cb..6159e925b 100644
--- a/doc/over_sampling.rst
+++ b/doc/over_sampling.rst
@@ -152,8 +152,8 @@ nearest neighbors class. Those variants are presented in the figure below.
    :align: center
 
-The :class:`BorderlineSMOTE` [HWB2005]_ and :class:`SVMSMOTE` [NCK2009]_ offer
-some variant of the SMOTE algorithm::
+The :class:`BorderlineSMOTE` [HWB2005]_, :class:`SVMSMOTE` [NCK2009]_, and
+:class:`KMeansSMOTE` [LDB2017]_ offer variants of the SMOTE algorithm::
 
    >>> from imblearn.over_sampling import BorderlineSMOTE
    >>> X_resampled, y_resampled = BorderlineSMOTE().fit_resample(X, y)
@@ -209,6 +209,10 @@ other extra interpolation.
      Knowledge Engineering and Soft Data Paradigms, 3(1), pp.4-21, 2009.
 
+   .. [LDB2017] Felix Last, Georgios Douzas, Fernando Bacao, "Oversampling for
+      Imbalanced Learning Based on K-Means and SMOTE"
+      https://arxiv.org/abs/1711.00837
+
 Mathematical formulation
 ========================
@@ -266,6 +270,10 @@ parameter of the SVM classifier allows to select more or less support vectors.
 For both borderline and SVM SMOTE, a neighborhood is defined using the
 parameter ``m_neighbors`` to decide if a sample is in danger, safe, or noise.
 
+**KMeans** SMOTE --- cf. :class:`KMeansSMOTE` --- applies a KMeans clustering
+before over-sampling with SMOTE. The clustering groups the samples together,
+and new samples are then generated within each cluster depending on its
+density.
+
 ADASYN works similarly to the regular SMOTE. However, the number of samples
 generated for each :math:`x_i` is proportional to the number of samples which
 are not from the same class than :math:`x_i` in a given
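To make the new user-guide paragraph concrete, here is a minimal usage
sketch (the dataset and parameter values are illustrative only and not part
of this changeset)::

    from collections import Counter

    from sklearn.datasets import make_classification

    from imblearn.over_sampling import KMeansSMOTE

    # an imbalanced two-class toy problem, ~10% minority
    X, y = make_classification(n_samples=5000, weights=[0.1, 0.9],
                               random_state=10)
    print(Counter(y))

    # KMeans clusters the data first; SMOTE then generates samples inside
    # the clusters where the minority class is sufficiently represented
    X_res, y_res = KMeansSMOTE(random_state=42).fit_resample(X, y)
    print(Counter(y_res))

Note that a ``RuntimeError`` is raised when no cluster passes the balance
filter; tuning ``kmeans_estimator`` or ``cluster_balance_threshold`` (both
introduced below) addresses this.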
diff --git a/doc/whats_new/v0.5.rst b/doc/whats_new/v0.5.rst
index 6ff15db79..a1a20baa6 100644
--- a/doc/whats_new/v0.5.rst
+++ b/doc/whats_new/v0.5.rst
@@ -37,6 +37,10 @@ Enhancement
    and issue template showing how to print system and dependency information
    from the command line. :pr:`557` by :user:`Alexander L. Hayes`.
 
+- Add :class:`imblearn.over_sampling.KMeansSMOTE` which is an over-sampler
+  that clusters the points before applying SMOTE.
+  :pr:`435` by :user:`Stephan Heijl`.
+
 Maintenance
 ...........
diff --git a/examples/over-sampling/plot_comparison_over_sampling.py b/examples/over-sampling/plot_comparison_over_sampling.py
index 97d8bcc4a..003e4e2c6 100644
--- a/examples/over-sampling/plot_comparison_over_sampling.py
+++ b/examples/over-sampling/plot_comparison_over_sampling.py
@@ -21,7 +21,8 @@
 from imblearn.pipeline import make_pipeline
 
 from imblearn.over_sampling import ADASYN
-from imblearn.over_sampling import SMOTE, BorderlineSMOTE, SVMSMOTE, SMOTENC
+from imblearn.over_sampling import (SMOTE, BorderlineSMOTE, SVMSMOTE, SMOTENC,
+                                    KMeansSMOTE)
 from imblearn.over_sampling import RandomOverSampler
 from imblearn.base import BaseSampler
 
@@ -204,18 +205,23 @@ def _fit_resample(self, X, y):
 # SMOTE proposes several variants by identifying specific samples to consider
 # during the resampling. The borderline version will detect which point to
 # select which are in the border between two classes. The SVM version will use
-# the support vectors found using an SVM algorithm to create new samples.
+# the support vectors found using an SVM algorithm to create new samples,
+# while the KMeans version will first cluster the data and then generate
+# samples in each cluster independently, depending on each cluster's density.
 
 fig, ((ax1, ax2), (ax3, ax4),
-      (ax5, ax6), (ax7, ax8)) = plt.subplots(4, 2, figsize=(15, 30))
+      (ax5, ax6), (ax7, ax8),
+      (ax9, ax10)) = plt.subplots(5, 2, figsize=(15, 30))
 X, y = create_dataset(n_samples=5000, weights=(0.01, 0.05, 0.94),
                       class_sep=0.8)
-ax_arr = ((ax1, ax2), (ax3, ax4), (ax5, ax6), (ax7, ax8))
+
+ax_arr = ((ax1, ax2), (ax3, ax4), (ax5, ax6), (ax7, ax8), (ax9, ax10))
 for ax, sampler in zip(ax_arr,
                        (SMOTE(random_state=0),
                         BorderlineSMOTE(random_state=0, kind='borderline-1'),
                         BorderlineSMOTE(random_state=0, kind='borderline-2'),
+                        KMeansSMOTE(random_state=0),
                         SVMSMOTE(random_state=0))):
     clf = make_pipeline(sampler, LinearSVC())
     clf.fit(X, y)
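For readers who do not want to run the full comparison script, each loop
iteration above boils down to the following self-contained sketch (plain
scatter plots stand in for the script's plotting helpers, and the dataset
parameters only mirror ``create_dataset`` for illustration; as noted above,
``KMeansSMOTE`` can raise a ``RuntimeError`` if no cluster passes the
balance filter)::

    import matplotlib.pyplot as plt
    from sklearn.datasets import make_classification

    from imblearn.over_sampling import KMeansSMOTE

    # three imbalanced classes, similar to the example's create_dataset
    X, y = make_classification(n_samples=5000, n_classes=3,
                               n_informative=3, class_sep=0.8,
                               weights=[0.01, 0.05, 0.94], random_state=0)

    sampler = KMeansSMOTE(random_state=0)
    X_res, y_res = sampler.fit_resample(X, y)

    # left panel: original data; right panel: resampled data
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    ax1.scatter(X[:, 0], X[:, 1], c=y, alpha=0.4)
    ax1.set_title('Original dataset')
    ax2.scatter(X_res[:, 0], X_res[:, 1], c=y_res, alpha=0.4)
    ax2.set_title('Resampled with KMeansSMOTE')
    plt.show()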
diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py
index 9cd63ac87..63abf3dc0 100644
--- a/imblearn/over_sampling/__init__.py
+++ b/imblearn/over_sampling/__init__.py
@@ -7,8 +7,9 @@
 from ._random_over_sampler import RandomOverSampler
 from ._smote import SMOTE
 from ._smote import BorderlineSMOTE
+from ._smote import KMeansSMOTE
 from ._smote import SVMSMOTE
 from ._smote import SMOTENC
 
-__all__ = ['ADASYN', 'RandomOverSampler',
+__all__ = ['ADASYN', 'RandomOverSampler', 'KMeansSMOTE',
            'SMOTE', 'BorderlineSMOTE', 'SVMSMOTE', 'SMOTENC']
diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py
index bf2cb41db..a7109ab48 100644
--- a/imblearn/over_sampling/_smote.py
+++ b/imblearn/over_sampling/_smote.py
@@ -8,6 +8,7 @@
 from __future__ import division
 
+import math
 import types
 import warnings
 from collections import Counter
@@ -16,6 +17,8 @@
 from scipy import sparse
 
 from sklearn.base import clone
+from sklearn.cluster import MiniBatchKMeans
+from sklearn.metrics import pairwise_distances
 from sklearn.preprocessing import OneHotEncoder
 from sklearn.svm import SVC
 from sklearn.utils import check_random_state
@@ -1090,3 +1093,236 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step):
             sample[start_idx + col_sel] = 1
         return sparse.csr_matrix(sample) if sparse.issparse(X) else sample
+
+
+@Substitution(
+    sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+    random_state=_random_state_docstring)
+class KMeansSMOTE(BaseSMOTE):
+    """Apply a KMeans clustering before over-sampling using SMOTE.
+
+    This is an implementation of the algorithm described in [1]_.
+
+    Read more in the :ref:`User Guide `.
+
+    Parameters
+    ----------
+    {sampling_strategy}
+
+    {random_state}
+
+    k_neighbors : int or object, optional (default=2)
+        If ``int``, number of nearest neighbours used to construct synthetic
+        samples. If object, an estimator that inherits from
+        :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to
+        find the k_neighbors.
+
+    n_jobs : int, optional (default=1)
+        The number of threads to open if possible.
+
+    kmeans_estimator : int or object, optional (default=MiniBatchKMeans())
+        A KMeans instance or the number of clusters to be used. By default,
+        a :class:`sklearn.cluster.MiniBatchKMeans` is used, which tends to
+        scale better with a large number of samples.
+
+    cluster_balance_threshold : str or float, optional (default="auto")
+        The threshold at which a cluster is called balanced and where samples
+        of the class selected for SMOTE will be oversampled. If "auto", this
+        will be determined by the ratio for each class, or it can be set
+        manually.
+
+    density_exponent : str or float, optional (default="auto")
+        This exponent is used to determine the density of a cluster. Leaving
+        this to "auto" will use a feature-length based exponent.
+
+    Attributes
+    ----------
+    kmeans_estimator_ : estimator
+        The fitted clustering method used before applying SMOTE.
+
+    nn_k_ : estimator
+        The fitted k-NN estimator used in SMOTE.
+
+    cluster_balance_threshold_ : float
+        The threshold used during ``fit`` for calling a cluster balanced.
+
+    References
+    ----------
+    .. [1] Felix Last, Georgios Douzas, Fernando Bacao, "Oversampling for
+       Imbalanced Learning Based on K-Means and SMOTE"
+       https://arxiv.org/abs/1711.00837
+
+    Examples
+    --------
+
+    >>> import numpy as np
+    >>> from imblearn.over_sampling import KMeansSMOTE
+    >>> from sklearn.datasets import make_blobs
+    >>> blobs = [100, 800, 100]
+    >>> X, y = make_blobs(blobs, centers=[(-10, 0), (0, 0), (10, 0)])
+    >>> # Add a single 0 sample in the middle blob
+    >>> X = np.concatenate([X, [[0, 0]]])
+    >>> y = np.append(y, 0)
+    >>> # Make this a binary classification problem
+    >>> y = y == 1
+    >>> sm = KMeansSMOTE(random_state=42)
+    >>> X_res, y_res = sm.fit_resample(X, y)
+    >>> # Find the number of new samples in the middle blob
+    >>> n_res_in_middle = ((X_res[:, 0] > -5) & (X_res[:, 0] < 5)).sum()
+    >>> print("Samples in the middle blob: %s" % n_res_in_middle)
+    Samples in the middle blob: 801
+    >>> print("Middle blob unchanged: %s" % (n_res_in_middle == blobs[1] + 1))
+    Middle blob unchanged: True
+    >>> print("More 0 samples: %s" % ((y_res == 0).sum() > (y == 0).sum()))
+    More 0 samples: True
+
+    """
+
+    def __init__(self,
+                 sampling_strategy='auto',
+                 random_state=None,
+                 k_neighbors=2,
+                 n_jobs=1,
+                 kmeans_estimator=None,
+                 cluster_balance_threshold="auto",
+                 density_exponent="auto"):
+        super().__init__(
+            sampling_strategy=sampling_strategy, random_state=random_state,
+            k_neighbors=k_neighbors, n_jobs=n_jobs)
+        self.kmeans_estimator = kmeans_estimator
+        self.cluster_balance_threshold = cluster_balance_threshold
+        self.density_exponent = density_exponent
+
+    def _validate_estimator(self):
+        super()._validate_estimator()
+        if self.kmeans_estimator is None:
+            self.kmeans_estimator_ = MiniBatchKMeans(
+                random_state=self.random_state)
+        elif isinstance(self.kmeans_estimator, int):
+            self.kmeans_estimator_ = MiniBatchKMeans(
+                n_clusters=self.kmeans_estimator,
+                random_state=self.random_state)
+        else:
+            self.kmeans_estimator_ = clone(self.kmeans_estimator)
+
+        # validate the parameters
+        for param_name in ('cluster_balance_threshold', 'density_exponent'):
+            param = getattr(self, param_name)
+            if isinstance(param, str) and param != 'auto':
+                raise ValueError(
+                    "'{}' should be 'auto' when a string is passed. "
+                    "Got {} instead.".format(param_name, repr(param))
+                )
+
+        self.cluster_balance_threshold_ = (
+            self.cluster_balance_threshold
+            if self.kmeans_estimator_.n_clusters != 1 else -np.inf
+        )
+
" + "Got {} instead.".format(param_name, repr(param)) + ) + + self.cluster_balance_threshold_ = ( + self.cluster_balance_threshold + if self.kmeans_estimator_.n_clusters != 1 else -np.inf + ) + + + def _find_cluster_sparsity(self, X): + """Compute the cluster sparsity.""" + euclidean_distances = pairwise_distances(X, metric="euclidean", + n_jobs=self.n_jobs) + # negate diagonal elements + for ind in range(X.shape[0]): + euclidean_distances[ind, ind] = 0 + + non_diag_elements = (X.shape[0] ** 2) - X.shape[0] + mean_distance = euclidean_distances.sum() / non_diag_elements + exponent = (math.log(X.shape[0], 1.6) ** 1.8 * 0.16 + if self.density_exponent == 'auto' + else self.density_exponent) + return (mean_distance ** exponent) / X.shape[0] + + # FIXME: rename _sample -> _fit_resample in 0.6 + def _fit_resample(self, X, y): + return self._sample(X, y) + + def _sample(self, X, y): + self._validate_estimator() + X_resampled = X.copy() + y_resampled = y.copy() + total_inp_samples = sum(self.sampling_strategy_.values()) + + for class_sample, n_samples in self.sampling_strategy_.items(): + if n_samples == 0: + continue + + # target_class_indices = np.flatnonzero(y == class_sample) + # X_class = safe_indexing(X, target_class_indices) + + X_clusters = self.kmeans_estimator_.fit_predict(X) + valid_clusters = [] + cluster_sparsities = [] + + # identify cluster which are answering the requirements + for cluster_idx in range(self.kmeans_estimator_.n_clusters): + + cluster_mask = np.flatnonzero(X_clusters == cluster_idx) + X_cluster = safe_indexing(X, cluster_mask) + y_cluster = safe_indexing(y, cluster_mask) + + cluster_class_mean = (y_cluster == class_sample).mean() + + if self.cluster_balance_threshold_ == "auto": + balance_threshold = n_samples / total_inp_samples / 2 + else: + balance_threshold = self.cluster_balance_threshold_ + + # the cluster is already considered balanced + if cluster_class_mean < balance_threshold: + continue + + # not enough samples to apply SMOTE + anticipated_samples = cluster_class_mean * X_cluster.shape[0] + if anticipated_samples < self.nn_k_.n_neighbors: + continue + + X_cluster_class = safe_indexing( + X_cluster, np.flatnonzero(y_cluster == class_sample) + ) + + valid_clusters.append(cluster_mask) + cluster_sparsities.append( + self._find_cluster_sparsity(X_cluster_class) + ) + + cluster_sparsities = np.array(cluster_sparsities) + cluster_weights = cluster_sparsities / cluster_sparsities.sum() + + if not valid_clusters: + raise RuntimeError( + "No clusters found with sufficient samples of " + "class {}. 
diff --git a/imblearn/over_sampling/tests/test_kmeans_smote.py b/imblearn/over_sampling/tests/test_kmeans_smote.py
new file mode 100644
index 000000000..9bb9e9a62
--- /dev/null
+++ b/imblearn/over_sampling/tests/test_kmeans_smote.py
@@ -0,0 +1,104 @@
+import pytest
+import numpy as np
+
+from sklearn.utils.testing import assert_allclose
+from sklearn.utils.testing import assert_array_equal
+
+from sklearn.cluster import KMeans
+from sklearn.cluster import MiniBatchKMeans
+from sklearn.neighbors import NearestNeighbors
+
+from imblearn.over_sampling import KMeansSMOTE
+from imblearn.over_sampling import SMOTE
+
+
+@pytest.fixture
+def data():
+    X = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141],
+                  [1.25192108, -0.22367336], [0.53366841, -0.30312976],
+                  [1.52091956, -0.49283504], [-0.28162401, -2.10400981],
+                  [0.83680821, 1.72827342], [0.3084254, 0.33299982],
+                  [0.70472253, -0.73309052], [0.28893132, -0.38761769],
+                  [1.15514042, 0.0129463], [0.88407872, 0.35454207],
+                  [1.31301027, -0.92648734], [-1.11515198, -0.93689695],
+                  [-0.18410027, -0.45194484], [0.9281014, 0.53085498],
+                  [-0.14374509, 0.27370049], [-0.41635887, -0.38299653],
+                  [0.08711622, 0.93259929], [1.70580611, -0.11219234]])
+    y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
+    return X, y
+
+
+def test_kmeans_smote(data):
+    X, y = data
+    kmeans_smote = KMeansSMOTE(kmeans_estimator=1,
+                               random_state=42,
+                               cluster_balance_threshold=0.0,
+                               k_neighbors=5)
+    smote = SMOTE(random_state=42)
+
+    X_res_1, y_res_1 = kmeans_smote.fit_sample(X, y)
+    X_res_2, y_res_2 = smote.fit_sample(X, y)
+
+    assert_allclose(X_res_1, X_res_2)
+    assert_array_equal(y_res_1, y_res_2)
+
+    assert kmeans_smote.nn_k_.n_neighbors == 6
+    assert kmeans_smote.kmeans_estimator_.n_clusters == 1
+    assert 'batch_size' in kmeans_smote.kmeans_estimator_.get_params()
+
+
+@pytest.mark.parametrize("k_neighbors", [2, NearestNeighbors(n_neighbors=3)])
+@pytest.mark.parametrize(
+    "kmeans_estimator",
+    [3,
+     KMeans(n_clusters=3, random_state=42),
+     MiniBatchKMeans(n_clusters=3, random_state=42)]
+)
+def test_sample_kmeans_custom(data, k_neighbors, kmeans_estimator):
+    X, y = data
+    kmeans_smote = KMeansSMOTE(random_state=42,
+                               kmeans_estimator=kmeans_estimator,
+                               k_neighbors=k_neighbors)
+    X_resampled, y_resampled = kmeans_smote.fit_sample(X, y)
+    assert X_resampled.shape == (24, 2)
+    assert y_resampled.shape == (24,)
+
+    assert kmeans_smote.nn_k_.n_neighbors == 3
+    assert kmeans_smote.kmeans_estimator_.n_clusters == 3
+
+
+def test_sample_kmeans_not_enough_clusters():
+    rng = np.random.RandomState(42)
+    X = rng.randn(30, 2)
+    y = np.array([1] * 20 + [0] * 10)
+
+    smote = KMeansSMOTE(random_state=42,
+                        kmeans_estimator=30,
+                        k_neighbors=2)
+    with pytest.raises(RuntimeError):
+        smote.fit_sample(X, y)
+
+
+@pytest.mark.parametrize("density_exponent", ["auto", 2])
+@pytest.mark.parametrize("cluster_balance_threshold", ["auto", 0.8])
+def test_sample_kmeans_density_estimation(data, density_exponent,
+                                          cluster_balance_threshold):
+    X, y = data
+    smote = KMeansSMOTE(random_state=42,
+                        density_exponent=density_exponent,
+                        cluster_balance_threshold=cluster_balance_threshold)
+    smote.fit_sample(X, y)
+
+
+@pytest.mark.parametrize(
+    "density_exponent, cluster_balance_threshold",
+    [('xxx', 'auto'), ('auto', 'xxx')]
+)
+def test_kmeans_smote_param_error(data, density_exponent,
+                                  cluster_balance_threshold):
+    X, y = data
+    kmeans_smote = KMeansSMOTE(
+        density_exponent=density_exponent,
+        cluster_balance_threshold=cluster_balance_threshold
+    )
+    with pytest.raises(ValueError, match="should be 'auto' when a string"):
+        kmeans_smote.fit_resample(X, y)
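The parametrized tests above pin down the int-or-estimator contract of
``kmeans_estimator`` and ``k_neighbors``. The same contract, sketched
outside of pytest (the separated minority blob is only there to guarantee
that at least one cluster passes the balance filter; the data is otherwise
arbitrary)::

    import numpy as np
    from sklearn.cluster import KMeans
    from sklearn.neighbors import NearestNeighbors

    from imblearn.over_sampling import KMeansSMOTE

    rng = np.random.RandomState(42)
    # a tight minority blob far away from the majority class
    X = np.vstack([rng.randn(12, 2) + [5, 5], rng.randn(28, 2)])
    y = np.array([0] * 12 + [1] * 28)

    # fully configured estimators can be passed instead of plain integers
    sampler = KMeansSMOTE(
        random_state=42,
        kmeans_estimator=KMeans(n_clusters=3, random_state=42),
        k_neighbors=NearestNeighbors(n_neighbors=3))
    X_res, y_res = sampler.fit_resample(X, y)
    print(X_res.shape, np.bincount(y_res))  # classes are now balanced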
diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py
index bb0734c9e..c967a005c 100644
--- a/imblearn/utils/estimator_checks.py
+++ b/imblearn/utils/estimator_checks.py
@@ -34,7 +34,7 @@
 from imblearn.over_sampling import SMOTE
 from imblearn.under_sampling import NearMiss, ClusterCentroids
 
-DONT_SUPPORT_RATIO = ['SVMSMOTE', 'BorderlineSMOTE']
+DONT_SUPPORT_RATIO = ['SVMSMOTE', 'BorderlineSMOTE', 'KMeansSMOTE']
 # FIXME: remove in 0.6
 DONT_HAVE_RANDOM_STATE = ('NearMiss', 'EditedNearestNeighbours',
                           'RepeatedEditedNearestNeighbours', 'AllKNN',
@@ -135,6 +135,7 @@ def check_samplers_one_label(name, Sampler):
 def check_samplers_fit(name, Sampler):
     sampler = Sampler()
+    np.random.seed(42)  # Make this test reproducible
     X = np.random.random((30, 2))
     y = np.array([1] * 20 + [0] * 10)
     sampler.fit_resample(X, y)