diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst index 0d387cc44..a100750ee 100644 --- a/doc/over_sampling.rst +++ b/doc/over_sampling.rst @@ -127,11 +127,11 @@ nearest neighbors class. Those variants are presented in the figure below. :align: center -The parameter ``kind`` is controlling this feature and the following types are -available: (i) ``'borderline1'``, (ii) ``'borderline2'``, and (iii) ``'svm'``:: +The :class:`BorderlineSMOTE` and :class:`SVMSMOTE` offer some variant of the SMOTE +algorithm:: - >>> from imblearn.over_sampling import SMOTE, ADASYN - >>> X_resampled, y_resampled = SMOTE(kind='borderline1').fit_sample(X, y) + >>> from imblearn.over_sampling import BorderlineSMOTE + >>> X_resampled, y_resampled = BorderlineSMOTE().fit_sample(X, y) >>> print(sorted(Counter(y_resampled).items())) [(0, 4674), (1, 4674), (2, 4674)] @@ -168,12 +168,11 @@ interpolation will create a sample on the line between :math:`x_{i}` and Each SMOTE variant and ADASYN differ from each other by selecting the samples :math:`x_i` ahead of generating the new samples. -The **regular** SMOTE algorithm --- cf. to ``kind='regular'`` when -instantiating a :class:`SMOTE` object --- does not impose any rule and will -randomly pick-up all possible :math:`x_i` available. +The **regular** SMOTE algorithm --- cf. to the :class:`SMOTE` object --- does not +impose any rule and will randomly pick-up all possible :math:`x_i` available. -The **borderline** SMOTE --- cf. to ``kind='borderline1'`` and -``kind='borderline2'`` when instantiating a :class:`SMOTE` object --- will +The **borderline** SMOTE --- cf. to the :class:`BorderlineSMOTE` with the +parameters ``kind='borderline-1'`` and ``kind='borderline-2'`` --- will classify each sample :math:`x_i` to be (i) noise (i.e. all nearest-neighbors are from a different class than the one of :math:`x_i`), (ii) in danger (i.e. 
at least half of the nearest neighbors are from the same class than @@ -184,10 +183,9 @@ samples *in danger* to generate new samples. In **Borderline-1** SMOTE, :math:`x_i`. On the contrary, **Borderline-2** SMOTE will consider :math:`x_{zi}` which can be from any class. -**SVM** SMOTE --- cf. to ``kind='svm'`` when instantiating a :class:`SMOTE` -object --- uses an SVM classifier to find support vectors and generate samples -considering them. Note that the ``C`` parameter of the SVM classifier allows to -select more or less support vectors. +**SVM** SMOTE --- cf. to :class:`SVMSMOTE` --- uses an SVM classifier to find +support vectors and generate samples considering them. Note that the ``C`` +parameter of the SVM classifier allows to select more or less support vectors. For both borderline and SVM SMOTE, a neighborhood is defined using the parameter ``m_neighbors`` to decide if a sample is in danger, safe, or noise. @@ -196,7 +194,7 @@ ADASYN is working similarly to the regular SMOTE. However, the number of samples generated for each :math:`x_i` is proportional to the number of samples which are not from the same class than :math:`x_i` in a given neighborhood. Therefore, more samples will be generated in the area that the -nearest neighbor rule is not respected. The parameter ``n_neighbors`` is +nearest neighbor rule is not respected. The parameter ``n_neighbors`` is equivalent to ``k_neighbors`` in :class:`SMOTE`. Multi-class management diff --git a/doc/whats_new/v0.0.4.rst b/doc/whats_new/v0.0.4.rst index 6546a57a0..e359a80fe 100644 --- a/doc/whats_new/v0.0.4.rst +++ b/doc/whats_new/v0.0.4.rst @@ -30,6 +30,10 @@ Enhancement - Add support for one-vs-all encoded target to support keras. :issue:`409` by :user:`Guillaume Lemaitre `. +- Add specific classes for borderline and SVM SMOTE using + :class:`BorderlineSMOTE` and :class:`SVMSMOTE`. + :issue:`440` by :user:`Guillaume Lemaitre `. + Bug fixes .........
@@ -63,3 +67,9 @@ Deprecation :class:`imblearn.under_sampling.NeighbourhoodCleaningRule`, :class:`imblearn.under_sampling.InstanceHardnessThreshold`, :class:`imblearn.under_sampling.CondensedNearestNeighbours`. + +- Deprecate ``kind``, ``out_step``, ``svm_estimator``, ``m_neighbors`` in + :class:`imblearn.over_sampling.SMOTE`. User should use + :class:`imblearn.over_sampling.SVMSMOTE` and + :class:`imblearn.over_sampling.BorderlineSMOTE`. + :issue:`440` by :user:`Guillaume Lemaitre `. diff --git a/examples/over-sampling/plot_comparison_over_sampling.py b/examples/over-sampling/plot_comparison_over_sampling.py index ba22f7bbd..41d395594 100644 --- a/examples/over-sampling/plot_comparison_over_sampling.py +++ b/examples/over-sampling/plot_comparison_over_sampling.py @@ -20,7 +20,9 @@ from sklearn.svm import LinearSVC from imblearn.pipeline import make_pipeline -from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler +from imblearn.over_sampling import ADASYN +from imblearn.over_sampling import SMOTE, BorderlineSMOTE, SVMSMOTE +from imblearn.over_sampling import RandomOverSampler from imblearn.base import SamplerMixin from imblearn.utils import hash_X_y @@ -220,21 +222,18 @@ def fit_sample(self, X, y): class_sep=0.8) ax_arr = ((ax1, ax2), (ax3, ax4), (ax5, ax6), (ax7, ax8)) -string_add = ['regular', 'borderline-1', 'borderline-2', 'SVM'] -for str_add, ax, sampler in zip(string_add, - ax_arr, - (SMOTE(random_state=0), - SMOTE(random_state=0, kind='borderline1'), - SMOTE(random_state=0, kind='borderline2'), - SMOTE(random_state=0, kind='svm'))): +for ax, sampler in zip(ax_arr, + (SMOTE(random_state=0), + BorderlineSMOTE(random_state=0, kind='borderline-1'), + BorderlineSMOTE(random_state=0, kind='borderline-2'), + SVMSMOTE(random_state=0))): clf = make_pipeline(sampler, LinearSVC()) clf.fit(X, y) plot_decision_function(X, y, clf, ax[0]) - ax[0].set_title('Decision function for {} {}'.format( - str_add, sampler.__class__.__name__)) + ax[0].set_title('Decision 
function for {}'.format( + sampler.__class__.__name__)) plot_resampling(X, y, sampler, ax[1]) - ax[1].set_title('Resampling using {} {}'.format( - str_add, sampler.__class__.__name__)) + ax[1].set_title('Resampling using {}'.format(sampler.__class__.__name__)) fig.tight_layout() plt.show() diff --git a/examples/over-sampling/plot_smote.py b/examples/over-sampling/plot_smote.py index b4fe22d3e..591720c2e 100644 --- a/examples/over-sampling/plot_smote.py +++ b/examples/over-sampling/plot_smote.py @@ -17,6 +17,8 @@ from sklearn.decomposition import PCA from imblearn.over_sampling import SMOTE +from imblearn.over_sampling import BorderlineSMOTE +from imblearn.over_sampling import SVMSMOTE print(__doc__) @@ -49,8 +51,8 @@ def plot_resampling(ax, X, y, title): X_vis = pca.fit_transform(X) # Apply regular SMOTE -kind = ['regular', 'borderline1', 'borderline2', 'svm'] -sm = [SMOTE(kind=k) for k in kind] +sm = [SMOTE(), BorderlineSMOTE(kind='borderline-1'), + BorderlineSMOTE(kind='borderline-2'), SVMSMOTE()] X_resampled = [] y_resampled = [] X_res_vis = [] @@ -67,9 +69,10 @@ def plot_resampling(ax, X, y, title): ax_res = [ax3, ax4, ax5, ax6] c0, c1 = plot_resampling(ax1, X_vis, y, 'Original set') -for i in range(len(kind)): +for i, name in enumerate(['SMOTE', 'SMOTE Borderline-1', + 'SMOTE Borderline-2', 'SMOTE SVM']): plot_resampling(ax_res[i], X_res_vis[i], y_resampled[i], - 'SMOTE {}'.format(kind[i])) + '{}'.format(name)) ax2.legend((c0, c1), ('Class #0', 'Class #1'), loc='center', ncol=1, labelspacing=0.) 
diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py index 3d92ef0a5..abfec8b80 100644 --- a/imblearn/over_sampling/__init__.py +++ b/imblearn/over_sampling/__init__.py @@ -6,5 +6,8 @@ from .adasyn import ADASYN from .random_over_sampler import RandomOverSampler from .smote import SMOTE +from .smote import BorderlineSMOTE +from .smote import SVMSMOTE -__all__ = ['ADASYN', 'RandomOverSampler', 'SMOTE'] +__all__ = ['ADASYN', 'RandomOverSampler', + 'SMOTE', 'BorderlineSMOTE', 'SVMSMOTE'] diff --git a/imblearn/over_sampling/smote.py b/imblearn/over_sampling/smote.py index 8e8b7965e..a6625b200 100644 --- a/imblearn/over_sampling/smote.py +++ b/imblearn/over_sampling/smote.py @@ -7,6 +7,9 @@ from __future__ import division +import types +import warnings + import numpy as np from scipy import sparse @@ -20,170 +23,31 @@ from ..utils import Substitution from ..utils._docstring import _random_state_docstring +# FIXME: remove in 0.6 SMOTE_KIND = ('regular', 'borderline1', 'borderline2', 'svm') -@Substitution( - sampling_strategy=BaseOverSampler._sampling_strategy_docstring, - random_state=_random_state_docstring) -class SMOTE(BaseOverSampler): - """Class to perform over-sampling using SMOTE. - - This object is an implementation of SMOTE - Synthetic Minority - Over-sampling Technique, and the variants Borderline SMOTE 1, 2 and - SVM-SMOTE. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - {sampling_strategy} - - {random_state} - - k_neighbors : int or object, optional (default=5) - If ``int``, number of nearest neighbours to used to construct synthetic - samples. If object, an estimator that inherits from - :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to - find the k_neighbors. - - m_neighbors : int or object, optional (default=10) - If int, number of nearest neighbours to use to determine if a minority - sample is in danger. Used with ``kind={{'borderline1', 'borderline2', - 'svm'}}``. 
If object, an estimator that inherits - from :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used - to find the k_neighbors. - - out_step : float, optional (default=0.5) - Step size when extrapolating. Used with ``kind='svm'``. - - kind : str, optional (default='regular') - The type of SMOTE algorithm to use one of the following options: - ``'regular'``, ``'borderline1'``, ``'borderline2'``, ``'svm'``. - - svm_estimator : object, optional (default=SVC()) - If ``kind='svm'``, a parametrized :class:`sklearn.svm.SVC` - classifier can be passed. - - n_jobs : int, optional (default=1) - The number of threads to open if possible. - - ratio : str, dict, or callable - .. deprecated:: 0.4 - Use the parameter ``sampling_strategy`` instead. It will be removed - in 0.6. - - Notes - ----- - See the original papers: [1]_, [2]_, [3]_ for more details. - - Supports multi-class resampling. A one-vs.-rest scheme is used as - originally proposed in [1]_. - - See - :ref:`sphx_glr_auto_examples_applications_plot_over_sampling_benchmark_lfw.py`, - :ref:`sphx_glr_auto_examples_evaluation_plot_classification_report.py`, - :ref:`sphx_glr_auto_examples_evaluation_plot_metrics.py`, - :ref:`sphx_glr_auto_examples_model_selection_plot_validation_curve.py`, - :ref:`sphx_glr_auto_examples_over-sampling_plot_comparison_over_sampling.py`, - and :ref:`sphx_glr_auto_examples_over-sampling_plot_smote.py`. - - See also - -------- - ADASYN : Over-sample using ADASYN. - - References - ---------- - .. [1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, "SMOTE: - synthetic minority over-sampling technique," Journal of artificial - intelligence research, 321-357, 2002. - - .. [2] H. Han, W. Wen-Yuan, M. Bing-Huan, "Borderline-SMOTE: a new - over-sampling method in imbalanced data sets learning," Advances in - intelligent computing, 878-887, 2005. - - .. [3] H. M. Nguyen, E. W. Cooper, K. 
Kamei, "Borderline over-sampling for - imbalanced data classification," International Journal of Knowledge - Engineering and Soft Data Paradigms, 3(1), pp.4-21, 2001. - - Examples - -------- - - >>> from collections import Counter - >>> from sklearn.datasets import make_classification - >>> from imblearn.over_sampling import \ -SMOTE # doctest: +NORMALIZE_WHITESPACE - >>> X, y = make_classification(n_classes=2, class_sep=2, - ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, - ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) - >>> print('Original dataset shape %s' % Counter(y)) - Original dataset shape Counter({{1: 900, 0: 100}}) - >>> sm = SMOTE(random_state=42) - >>> X_res, y_res = sm.fit_sample(X, y) - >>> print('Resampled dataset shape %s' % Counter(y_res)) - Resampled dataset shape Counter({{0: 900, 1: 900}}) - - """ - +class BaseSMOTE(BaseOverSampler): + """Base class for the different SMOTE algorithms.""" def __init__(self, sampling_strategy='auto', random_state=None, k_neighbors=5, - m_neighbors=10, - out_step=0.5, - kind='regular', - svm_estimator=None, n_jobs=1, ratio=None): - super(SMOTE, self).__init__( + super(BaseSMOTE, self).__init__( sampling_strategy=sampling_strategy, ratio=ratio) self.random_state = random_state - self.kind = kind self.k_neighbors = k_neighbors - self.m_neighbors = m_neighbors - self.out_step = out_step - self.svm_estimator = svm_estimator self.n_jobs = n_jobs - def _in_danger_noise(self, samples, target_class, y, kind='danger'): - """Estimate if a set of sample are in danger or noise. - - Parameters - ---------- - samples : {array-like, sparse matrix}, shape (n_samples, n_features) - The samples to check if either they are in danger or not. - - target_class : int or str, - The target corresponding class being over-sampled. - - y : array-like, shape (n_samples,) - The true label in order to check the neighbour labels. 
- - kind : str, optional (default='danger') - The type of classification to use. Can be either: - - - If 'danger', check if samples are in danger, - - If 'noise', check if samples are noise. - - Returns - ------- - output : ndarray, shape (n_samples,) - A boolean array where True refer to samples in danger or noise. - + def _validate_estimator(self): + """Check the NN estimators shared across the different SMOTE + algorithms. """ - x = self.nn_m_.kneighbors(samples, return_distance=False)[:, 1:] - nn_label = (y[x] != target_class).astype(int) - n_maj = np.sum(nn_label, axis=1) - - if kind == 'danger': - # Samples are in danger for m/2 <= m' < m - return np.bitwise_and(n_maj >= (self.nn_m_.n_neighbors - 1) / 2, - n_maj < self.nn_m_.n_neighbors - 1) - elif kind == 'noise': - # Samples are noise for m = m' - return n_maj == self.nn_m_.n_neighbors - 1 - else: - raise NotImplementedError + self.nn_k_ = check_neighbors_object( + 'k_neighbors', self.k_neighbors, additional_neighbor=1) + self.nn_k_.set_params(**{'n_jobs': self.n_jobs}) def _make_samples(self, X, @@ -256,116 +120,162 @@ def _make_samples(self, else: return X_new, y_new - def _validate_estimator(self): - """Create the necessary objects for SMOTE.""" - - if self.kind not in SMOTE_KIND: - raise ValueError('Unknown kind for SMOTE algorithm.' - ' Choices are {}. Got {} instead.'.format( - SMOTE_KIND, self.kind)) + def _in_danger_noise(self, nn_estimator, samples, target_class, y, + kind='danger'): + """Estimate if a set of sample are in danger or noise. - self.nn_k_ = check_neighbors_object( - 'k_neighbors', self.k_neighbors, additional_neighbor=1) - self.nn_k_.set_params(**{'n_jobs': self.n_jobs}) + Used by BorderlineSMOTE and SVMSMOTE. 
- if self.kind != 'regular': - self.nn_m_ = check_neighbors_object( - 'm_neighbors', self.m_neighbors, additional_neighbor=1) - self.nn_m_.set_params(**{'n_jobs': self.n_jobs}) + Parameters + ---------- + nn_estimator : estimator + An estimator that inherits from + :class:`sklearn.neighbors.base.KNeighborsMixin` use to determine if + a sample is in danger/noise. - if self.kind == 'svm': - if self.svm_estimator is None: - self.svm_estimator_ = SVC(random_state=self.random_state) - elif isinstance(self.svm_estimator, SVC): - self.svm_estimator_ = self.svm_estimator - else: - raise_isinstance_error('svm_estimator', [SVC], - self.svm_estimator) + samples : {array-like, sparse matrix}, shape (n_samples, n_features) + The samples to check if either they are in danger or not. - def _sample_regular(self, X, y): - """Resample the dataset using the regular SMOTE implementation. + target_class : int or str + The target corresponding class being over-sampled. - Use the regular SMOTE algorithm proposed in [1]_. + y : array-like, shape (n_samples,) + The true label in order to check the neighbour labels. - Parameters - ---------- - X : {array-like, sparse matrix}, shape (n_samples, n_features) - Matrix containing the data which have to be sampled. + kind : str, optional (default='danger') + The type of classification to use. Can be either: - y : array-like, shape (n_samples,) - Corresponding label for each sample in X. + - If 'danger', check if samples are in danger, + - If 'noise', check if samples are noise. Returns ------- - X_resampled : {ndarray, sparse matrix}, shape \ -(n_samples_new, n_features) - The array containing the resampled data. + output : ndarray, shape (n_samples,) + A boolean array where True refer to samples in danger or noise. 
- y_resampled : ndarray, shape (n_samples_new,) - The corresponding label of `X_resampled` + """ + x = nn_estimator.kneighbors(samples, return_distance=False)[:, 1:] + nn_label = (y[x] != target_class).astype(int) + n_maj = np.sum(nn_label, axis=1) - References - ---------- - .. [1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, "SMOTE: - synthetic minority over-sampling technique," Journal of artificial - intelligence research, 321-357, 2002. + if kind == 'danger': + # Samples are in danger for m/2 <= m' < m + return np.bitwise_and(n_maj >= (nn_estimator.n_neighbors - 1) / 2, + n_maj < nn_estimator.n_neighbors - 1) + elif kind == 'noise': + # Samples are noise for m = m' + return n_maj == nn_estimator.n_neighbors - 1 + else: + raise NotImplementedError - """ - X_resampled = X.copy() - y_resampled = y.copy() +@Substitution( + sampling_strategy=BaseOverSampler._sampling_strategy_docstring, + random_state=_random_state_docstring) +class BorderlineSMOTE(BaseSMOTE): + """Over-sampling using Borderline SMOTE. - for class_sample, n_samples in self.sampling_strategy_.items(): - if n_samples == 0: - continue - target_class_indices = np.flatnonzero(y == class_sample) - X_class = safe_indexing(X, target_class_indices) + This algorithm is a variant of the original SMOTE algorithm proposed in + [2]_. Borderline samples will be detected and used to generate new + synthetic samples. - self.nn_k_.fit(X_class) - nns = self.nn_k_.kneighbors(X_class, return_distance=False)[:, 1:] - X_new, y_new = self._make_samples(X_class, class_sample, X_class, - nns, n_samples, 1.0) + Read more in the :ref:`User Guide `. 
- if sparse.issparse(X_new): - X_resampled = sparse.vstack([X_resampled, X_new]) - else: - X_resampled = np.vstack((X_resampled, X_new)) - y_resampled = np.hstack((y_resampled, y_new)) + Parameters + ---------- + {sampling_strategy} - return X_resampled, y_resampled + {random_state} - def _sample_borderline(self, X, y): - """Resample the dataset using the borderline SMOTE implementation. + k_neighbors : int or object, optional (default=5) + If ``int``, number of nearest neighbours to used to construct synthetic + samples. If object, an estimator that inherits from + :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to + find the k_neighbors. - Use the borderline SMOTE algorithm proposed in [2]_. Two methods can be - used: (i) borderline-1 or (ii) borderline-2. A nearest-neighbours - algorithm is used to determine the samples forming the boundaries and - will create samples next to those features depending on some criterion. + n_jobs : int, optional (default=1) + The number of threads to open if possible. - Parameters - ---------- - X : {array-like, sparse matrix}, shape (n_samples, n_features) - Matrix containing the data which have to be sampled. + m_neighbors : int or object, optional (default=10) + If int, number of nearest neighbours to use to determine if a minority + sample is in danger. If object, an estimator that inherits + from :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used + to find the m_neighbors. - y : array-like, shape (n_samples,) - Corresponding label for each sample in X. + kind : str, optional (default='borderline-1') + The type of SMOTE algorithm to use one of the following options: + ``'borderline-1'``, ``'borderline-2'``. - Returns - ------- - X_resampled : {ndarray, sparse matrix}, shape \ -(n_samples_new, n_features) - The array containing the resampled data. + Notes + ----- + See the original papers: [2]_ for more details. 
- y_resampled : ndarray, shape (n_samples_new,) - The corresponding label of `X_resampled` + Supports multi-class resampling. A one-vs.-rest scheme is used as + originally proposed in [1]_. - References - ---------- - .. [2] H. Han, W. Wen-Yuan, M. Bing-Huan, "Borderline-SMOTE: a new - over-sampling method in imbalanced data sets learning," Advances in - intelligent computing, 878-887, 2005. + See also + -------- + SMOTE : Over-sample using SMOTE. + + SVMSMOTE : Over-sample using SVM-SMOTE variant. + + ADASYN : Over-sample using ADASYN. + + References + ---------- + .. [1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, "SMOTE: + synthetic minority over-sampling technique," Journal of artificial + intelligence research, 321-357, 2002. + + .. [2] H. Han, W. Wen-Yuan, M. Bing-Huan, "Borderline-SMOTE: a new + over-sampling method in imbalanced data sets learning," Advances in + intelligent computing, 878-887, 2005. + + Examples + -------- + + >>> from collections import Counter + >>> from sklearn.datasets import make_classification + >>> from imblearn.over_sampling import \ +BorderlineSMOTE # doctest: +NORMALIZE_WHITESPACE + >>> X, y = make_classification(n_classes=2, class_sep=2, + ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, + ... 
n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) + >>> print('Original dataset shape %s' % Counter(y)) + Original dataset shape Counter({{1: 900, 0: 100}}) + >>> sm = BorderlineSMOTE(random_state=42) + >>> X_res, y_res = sm.fit_sample(X, y) + >>> print('Resampled dataset shape %s' % Counter(y_res)) + Resampled dataset shape Counter({{0: 900, 1: 900}}) + + """ + + def __init__(self, + sampling_strategy='auto', + random_state=None, + k_neighbors=5, + n_jobs=1, + m_neighbors=10, + kind='borderline-1'): + super(BorderlineSMOTE, self).__init__( + sampling_strategy=sampling_strategy, random_state=random_state, + k_neighbors=k_neighbors, n_jobs=n_jobs, ratio=None) + self.m_neighbors = m_neighbors + self.kind = kind + + def _validate_estimator(self): + super(BorderlineSMOTE, self)._validate_estimator() + self.nn_m_ = check_neighbors_object( + 'm_neighbors', self.m_neighbors, additional_neighbor=1) + self.nn_m_.set_params(**{'n_jobs': self.n_jobs}) + if self.kind not in ('borderline-1', 'borderline-2'): + raise ValueError('The possible "kind" of algorithm are ' + '"borderline-1" and "borderline-2". ' + 'Got {} instead.'.format(self.kind)) + + def _sample(self, X, y): + self._validate_estimator() - """ X_resampled = X.copy() y_resampled = y.copy() @@ -377,7 +287,7 @@ def _sample_borderline(self, X, y): self.nn_m_.fit(X) danger_index = self._in_danger_noise( - X_class, class_sample, y, kind='danger') + self.nn_m_, X_class, class_sample, y, kind='danger') if not any(danger_index): continue @@ -387,7 +297,7 @@ def _sample_borderline(self, X, y): return_distance=False)[:, 1:] # divergence between borderline-1 and borderline-2 - if self.kind == 'borderline1': + if self.kind == 'borderline-1': # Create synthetic samples for borderline points.
X_new, y_new = self._make_samples( safe_indexing(X_class, danger_index), class_sample, @@ -398,7 +308,7 @@ def _sample_borderline(self, X, y): X_resampled = np.vstack((X_resampled, X_new)) y_resampled = np.hstack((y_resampled, y_new)) - else: + elif self.kind == 'borderline-2': random_state = check_random_state(self.random_state) fractions = random_state.beta(10, 10) @@ -431,36 +341,120 @@ def _sample_borderline(self, X, y): return X_resampled, y_resampled - def _sample_svm(self, X, y): - """Resample the dataset using the SVM SMOTE implementation. - Use the SVM SMOTE algorithm proposed in [3]_. A SVM classifier detect - support vectors to get a notion of the boundary. +@Substitution( + sampling_strategy=BaseOverSampler._sampling_strategy_docstring, + random_state=_random_state_docstring) +class SVMSMOTE(BaseSMOTE): + """Over-sampling using SVM-SMOTE. - Parameters - ---------- - X : {array-like, sparse matrix}, shape (n_samples, n_features) - Matrix containing the data which have to be sampled. + Variant of SMOTE algorithm which use an SVM algorithm to detect sample to + use for generating new synthetic samples as proposed in [2]_. - y : array-like, shape (n_samples,) - Corresponding label for each sample in X. + Read more in the :ref:`User Guide `. - Returns - ------- - X_resampled : {ndarray, sparse matrix}, shape \ -(n_samples_new, n_features) - The array containing the resampled data. + Parameters + ---------- + {sampling_strategy} - y_resampled : ndarray, shape (n_samples_new,) - The corresponding label of `X_resampled` + {random_state} - References - ---------- - .. [3] H. M. Nguyen, E. W. Cooper, K. Kamei, "Borderline over-sampling - for imbalanced data classification," International Journal of - Knowledge Engineering and Soft Data Paradigms, 3(1), pp.4-21, 2001. + k_neighbors : int or object, optional (default=5) + If ``int``, number of nearest neighbours to used to construct synthetic + samples. 
If object, an estimator that inherits from + :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to + find the k_neighbors. - """ + n_jobs : int, optional (default=1) + The number of threads to open if possible. + + m_neighbors : int or object, optional (default=10) + If int, number of nearest neighbours to use to determine if a minority + sample is in danger. If object, an estimator that inherits from + :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to + find the m_neighbors. + + svm_estimator : object, optional (default=SVC()) + A parametrized :class:`sklearn.svm.SVC` classifier can be passed. + + out_step : float, optional (default=0.5) + Step size when extrapolating. + + Notes + ----- + See the original papers: [2]_ for more details. + + Supports multi-class resampling. A one-vs.-rest scheme is used as + originally proposed in [1]_. + + See also + -------- + SMOTE : Over-sample using SMOTE. + + BorderlineSMOTE : Over-sample using Borderline-SMOTE. + + ADASYN : Over-sample using ADASYN. + + References + ---------- + .. [1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, "SMOTE: + synthetic minority over-sampling technique," Journal of artificial + intelligence research, 321-357, 2002. + + .. [2] H. M. Nguyen, E. W. Cooper, K. Kamei, "Borderline over-sampling for + imbalanced data classification," International Journal of Knowledge + Engineering and Soft Data Paradigms, 3(1), pp.4-21, 2009. + + Examples + -------- + + >>> from collections import Counter + >>> from sklearn.datasets import make_classification + >>> from imblearn.over_sampling import \ +SVMSMOTE # doctest: +NORMALIZE_WHITESPACE + >>> X, y = make_classification(n_classes=2, class_sep=2, + ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, + ... 
n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) + >>> print('Original dataset shape %s' % Counter(y)) + Original dataset shape Counter({{1: 900, 0: 100}}) + >>> sm = SVMSMOTE(random_state=42) + >>> X_res, y_res = sm.fit_sample(X, y) + >>> print('Resampled dataset shape %s' % Counter(y_res)) + Resampled dataset shape Counter({{0: 900, 1: 900}}) + + """ + + def __init__(self, + sampling_strategy='auto', + random_state=None, + k_neighbors=5, + n_jobs=1, + m_neighbors=10, + svm_estimator=None, + out_step=0.5): + super(SVMSMOTE, self).__init__( + sampling_strategy=sampling_strategy, random_state=random_state, + k_neighbors=k_neighbors, n_jobs=n_jobs, ratio=None) + self.m_neighbors = m_neighbors + self.svm_estimator = svm_estimator + self.out_step = out_step + + def _validate_estimator(self): + super(SVMSMOTE, self)._validate_estimator() + self.nn_m_ = check_neighbors_object( + 'm_neighbors', self.m_neighbors, additional_neighbor=1) + self.nn_m_.set_params(**{'n_jobs': self.n_jobs}) + + if self.svm_estimator is None: + self.svm_estimator_ = SVC(random_state=self.random_state) + elif isinstance(self.svm_estimator, SVC): + self.svm_estimator_ = self.svm_estimator + else: + raise_isinstance_error('svm_estimator', [SVC], + self.svm_estimator) + + def _sample(self, X, y): + self._validate_estimator() random_state = check_random_state(self.random_state) X_resampled = X.copy() y_resampled = y.copy() @@ -478,15 +472,16 @@ def _sample_svm(self, X, y): self.nn_m_.fit(X) noise_bool = self._in_danger_noise( - support_vector, class_sample, y, kind='noise') + self.nn_m_, support_vector, class_sample, y, kind='noise') support_vector = safe_indexing( support_vector, np.flatnonzero(np.logical_not(noise_bool))) danger_bool = self._in_danger_noise( - support_vector, class_sample, y, kind='danger') + self.nn_m_, support_vector, class_sample, y, kind='danger') safety_bool = np.logical_not(danger_bool) self.nn_k_.fit(X_class) fractions = random_state.beta(10, 10) +
n_generated_samples = int(fractions * (n_samples + 1)) if np.count_nonzero(danger_bool) > 0: nns = self.nn_k_.kneighbors( safe_indexing(support_vector, np.flatnonzero(danger_bool)), @@ -497,7 +492,7 @@ def _sample_svm(self, X, y): class_sample, X_class, nns, - int(fractions * (n_samples + 1)), + n_generated_samples, step_size=1.) if np.count_nonzero(safety_bool) > 0: @@ -510,7 +505,7 @@ def _sample_svm(self, X, y): class_sample, X_class, nns, - int((1 - fractions) * n_samples), + n_samples - n_generated_samples, step_size=-self.out_step) if (np.count_nonzero(danger_bool) > 0 and @@ -537,32 +532,218 @@ def _sample_svm(self, X, y): return X_resampled, y_resampled - def _sample(self, X, y): - """Resample the dataset. - Parameters - ---------- - X : {array-like, sparse matrix}, shape (n_samples, n_features) - Matrix containing the data which have to be sampled. +# FIXME: In 0.6, SMOTE should inherit only from BaseSMOTE. +@Substitution( + sampling_strategy=BaseOverSampler._sampling_strategy_docstring, + random_state=_random_state_docstring) +class SMOTE(SVMSMOTE, BorderlineSMOTE): + """Class to perform over-sampling using SMOTE. - y : array-like, shape (n_samples,) - Corresponding label for each sample in X. + This object is an implementation of SMOTE - Synthetic Minority + Over-sampling Technique as presented in [1]_. - Returns - ------- - X_resampled : {ndarray, sparse matrix}, shape \ -(n_samples_new, n_features) - The array containing the resampled data. + Read more in the :ref:`User Guide `. - y_resampled : ndarray, shape (n_samples_new,) - The corresponding label of `X_resampled` + Parameters + ---------- + {sampling_strategy} - """ + {random_state} + + k_neighbors : int or object, optional (default=5) + If ``int``, number of nearest neighbours to used to construct synthetic + samples. If object, an estimator that inherits from + :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to + find the k_neighbors. 
+ + m_neighbors : int or object, optional (default=10) + If int, number of nearest neighbours to use to determine if a minority + sample is in danger. Used with ``kind={{'borderline1', 'borderline2', + 'svm'}}``. If object, an estimator that inherits + from :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used + to find the m_neighbors. + + .. deprecated:: 0.4 + ``m_neighbors`` is deprecated in 0.4 and will be removed in 0.6. Use + :class:`BorderlineSMOTE` or :class:`SVMSMOTE` instead to use the + intended algorithm. + + out_step : float, optional (default=0.5) + Step size when extrapolating. Used with ``kind='svm'``. + + .. deprecated:: 0.4 + ``out_step`` is deprecated in 0.4 and will be removed in 0.6. Use + :class:`SVMSMOTE` instead to use the intended algorithm. + + kind : str, optional (default='regular') + The type of SMOTE algorithm to use one of the following options: + ``'regular'``, ``'borderline1'``, ``'borderline2'``, ``'svm'``. + + .. deprecated:: 0.4 + ``kind`` is deprecated in 0.4 and will be removed in 0.6. Use + :class:`BorderlineSMOTE` or :class:`SVMSMOTE` instead to use the + intended algorithm. + + svm_estimator : object, optional (default=SVC()) + If ``kind='svm'``, a parametrized :class:`sklearn.svm.SVC` + classifier can be passed. + + .. deprecated:: 0.4 + ``svm_estimator`` is deprecated in 0.4 and will be removed in 0.6. Use + :class:`SVMSMOTE` instead to use the intended algorithm. + + n_jobs : int, optional (default=1) + The number of threads to open if possible. + + ratio : str, dict, or callable + .. deprecated:: 0.4 + Use the parameter ``sampling_strategy`` instead. It will be removed + in 0.6. + + Notes + ----- + See the original papers: [1]_ for more details. + + Supports multi-class resampling. A one-vs.-rest scheme is used as + originally proposed in [1]_. + + See also + -------- + BorderlineSMOTE : Over-sample using the borderline-SMOTE variant. + + SVMSMOTE : Over-sample using the SVM-SMOTE variant.
+ + ADASYN : Over-sample using ADASYN. + + References + ---------- + .. [1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, "SMOTE: + synthetic minority over-sampling technique," Journal of artificial + intelligence research, 321-357, 2002. + + Examples + -------- + + >>> from collections import Counter + >>> from sklearn.datasets import make_classification + >>> from imblearn.over_sampling import \ +SMOTE # doctest: +NORMALIZE_WHITESPACE + >>> X, y = make_classification(n_classes=2, class_sep=2, + ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, + ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) + >>> print('Original dataset shape %s' % Counter(y)) + Original dataset shape Counter({{1: 900, 0: 100}}) + >>> sm = SMOTE(random_state=42) + >>> X_res, y_res = sm.fit_sample(X, y) + >>> print('Resampled dataset shape %s' % Counter(y_res)) + Resampled dataset shape Counter({{0: 900, 1: 900}}) + + """ + def __init__(self, + sampling_strategy='auto', + random_state=None, + k_neighbors=5, + m_neighbors='deprecated', + out_step='deprecated', + kind='deprecated', + svm_estimator='deprecated', + n_jobs=1, + ratio=None): + # FIXME: in 0.6 call super() + BaseSMOTE.__init__(self, sampling_strategy=sampling_strategy, + random_state=random_state, k_neighbors=k_neighbors, + n_jobs=n_jobs, ratio=ratio) + self.kind = kind + self.m_neighbors = m_neighbors + self.out_step = out_step + self.svm_estimator = svm_estimator + self.n_jobs = n_jobs + + def _validate_estimator(self): + # FIXME: in 0.6 call super() + BaseSMOTE._validate_estimator(self) + # FIXME: remove in 0.6 after deprecation cycle + if self.kind != 'deprecated' and not (self.kind == 'borderline-1' or + self.kind == 'borderline-2'): + if self.kind not in SMOTE_KIND: + raise ValueError('Unknown kind for SMOTE algorithm.' + ' Choices are {}. Got {} instead.'.format( + SMOTE_KIND, self.kind)) + else: + warnings.warn('"kind" is deprecated in 0.4 and will be ' + 'removed in 0.6. 
Use SMOTE, BorderlineSMOTE or ' + 'SVMSMOTE instead.', DeprecationWarning) + + if self.kind == 'borderline1' or self.kind == 'borderline2': + self._sample = types.MethodType(BorderlineSMOTE._sample, self) + self.kind = ('borderline-1' if self.kind == 'borderline1' + else 'borderline-2') + + elif self.kind == 'svm': + self._sample = types.MethodType(SVMSMOTE._sample, self) + + if self.out_step == 'deprecated': + self.out_step = 0.5 + else: + warnings.warn('"out_step" is deprecated in 0.4 and will ' + 'be removed in 0.6. Use SVMSMOTE class ' + 'instead.', DeprecationWarning) + + if self.svm_estimator == 'deprecated': + warnings.warn('"svm_estimator" is deprecated in 0.4 and ' + 'will be removed in 0.6. Use SVMSMOTE class ' + 'instead.', DeprecationWarning) + if (self.svm_estimator is None or + self.svm_estimator == 'deprecated'): + self.svm_estimator_ = SVC(random_state=self.random_state) + elif isinstance(self.svm_estimator, SVC): + self.svm_estimator_ = self.svm_estimator + else: + raise_isinstance_error('svm_estimator', [SVC], + self.svm_estimator) + + if self.kind != 'regular': + if self.m_neighbors == 'deprecated': + self.m_neighbors = 10 + else: + warnings.warn('"m_neighbors" is deprecated in 0.4 and ' + 'will be removed in 0.6. 
Use SVMSMOTE class ' + 'or BorderlineSMOTE instead.', + DeprecationWarning) + + self.nn_m_ = check_neighbors_object( + 'm_neighbors', self.m_neighbors, additional_neighbor=1) + self.nn_m_.set_params(**{'n_jobs': self.n_jobs}) + + # FIXME: to be removed in 0.6 + def fit(self, X, y): self._validate_estimator() + BaseSMOTE.fit(self, X, y) + return self + + def _sample(self, X, y): + # FIXME: uncomment in version 0.6 + # self._validate_estimator() + + X_resampled = X.copy() + y_resampled = y.copy() + + for class_sample, n_samples in self.sampling_strategy_.items(): + if n_samples == 0: + continue + target_class_indices = np.flatnonzero(y == class_sample) + X_class = safe_indexing(X, target_class_indices) + + self.nn_k_.fit(X_class) + nns = self.nn_k_.kneighbors(X_class, return_distance=False)[:, 1:] + X_new, y_new = self._make_samples(X_class, class_sample, X_class, + nns, n_samples, 1.0) - if self.kind == 'regular': - return self._sample_regular(X, y) - elif self.kind == 'borderline1' or self.kind == 'borderline2': - return self._sample_borderline(X, y) - elif self.kind == 'svm': - return self._sample_svm(X, y) + if sparse.issparse(X_new): + X_resampled = sparse.vstack([X_resampled, X_new]) + else: + X_resampled = np.vstack((X_resampled, X_new)) + y_resampled = np.hstack((y_resampled, y_new)) + + return X_resampled, y_resampled diff --git a/imblearn/over_sampling/tests/test_smote.py b/imblearn/over_sampling/tests/test_smote.py index 5346c39fd..5e5a22800 100644 --- a/imblearn/over_sampling/tests/test_smote.py +++ b/imblearn/over_sampling/tests/test_smote.py @@ -6,28 +6,27 @@ from __future__ import print_function import numpy as np -from pytest import raises +import pytest from sklearn.utils.testing import assert_allclose, assert_array_equal from sklearn.neighbors import NearestNeighbors from sklearn.svm import SVC from imblearn.over_sampling import SMOTE +from imblearn.over_sampling import BorderlineSMOTE +from imblearn.over_sampling import SVMSMOTE RND_SEED = 0 -X = 
np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141], [ - 1.25192108, -0.22367336 -], [0.53366841, -0.30312976], [1.52091956, - -0.49283504], [-0.28162401, -2.10400981], - [0.83680821, - 1.72827342], [0.3084254, 0.33299982], [0.70472253, -0.73309052], - [0.28893132, -0.38761769], [1.15514042, 0.0129463], [ - 0.88407872, 0.35454207 - ], [1.31301027, -0.92648734], [-1.11515198, -0.93689695], [ - -0.18410027, -0.45194484 - ], [0.9281014, 0.53085498], [-0.14374509, 0.27370049], [ - -0.41635887, -0.38299653 - ], [0.08711622, 0.93259929], [1.70580611, -0.11219234]]) +X = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141], + [1.25192108, -0.22367336], [0.53366841, -0.30312976], + [1.52091956, -0.49283504], [-0.28162401, -2.10400981], + [0.83680821, 1.72827342], [0.3084254, 0.33299982], + [0.70472253, -0.73309052], [0.28893132, -0.38761769], + [1.15514042, 0.0129463], [0.88407872, 0.35454207], + [1.31301027, -0.92648734], [-1.11515198, -0.93689695], + [-0.18410027, -0.45194484], [0.9281014, 0.53085498], + [-0.14374509, 0.27370049], [-0.41635887, -0.38299653], + [0.08711622, 0.93259929], [1.70580611, -0.11219234]]) Y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0]) R_TOL = 1e-4 @@ -35,13 +34,12 @@ def test_smote_wrong_kind(): kind = 'rnd' smote = SMOTE(kind=kind, random_state=RND_SEED) - with raises(ValueError, match="Unknown kind for SMOTE"): + with pytest.raises(ValueError, match="Unknown kind for SMOTE"): smote.fit_sample(X, Y) def test_sample_regular(): - kind = 'regular' - smote = SMOTE(random_state=RND_SEED, kind=kind) + smote = SMOTE(random_state=RND_SEED) X_resampled, y_resampled = smote.fit_sample(X, Y) X_gt = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141], [ 1.25192108, -0.22367336 @@ -67,9 +65,8 @@ def test_sample_regular(): def test_sample_regular_half(): sampling_strategy = {0: 9, 1: 12} - kind = 'regular' smote = SMOTE( - sampling_strategy=sampling_strategy, random_state=RND_SEED, kind=kind) + 
sampling_strategy=sampling_strategy, random_state=RND_SEED) X_resampled, y_resampled = smote.fit_sample(X, Y) X_gt = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141], [ 1.25192108, -0.22367336 @@ -90,6 +87,7 @@ def test_sample_regular_half(): assert_array_equal(y_resampled, y_gt) +@pytest.mark.filterwarnings('ignore:"kind" is deprecated in 0.4 and will be') def test_sample_borderline1(): kind = 'borderline1' smote = SMOTE(random_state=RND_SEED, kind=kind) @@ -116,6 +114,7 @@ def test_sample_borderline1(): assert_array_equal(y_resampled, y_gt) +@pytest.mark.filterwarnings('ignore:"kind" is deprecated in 0.4 and will be') def test_sample_borderline2(): kind = 'borderline2' smote = SMOTE(random_state=RND_SEED, kind=kind) @@ -140,30 +139,46 @@ def test_sample_borderline2(): assert_array_equal(y_resampled, y_gt) +@pytest.mark.filterwarnings('ignore:"kind" is deprecated in 0.4 and will be') +@pytest.mark.filterwarnings('ignore:"svm_estimator" is deprecated in 0.4 and') +@pytest.mark.filterwarnings('ignore:"out_step" is deprecated in 0.4 and') +@pytest.mark.filterwarnings('ignore:"m_neighbors" is deprecated in 0.4 and') def test_sample_svm(): kind = 'svm' smote = SMOTE(random_state=RND_SEED, kind=kind) X_resampled, y_resampled = smote.fit_sample(X, Y) - X_gt = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141], [ - 1.25192108, -0.22367336 - ], [0.53366841, -0.30312976], [1.52091956, -0.49283504], [ - -0.28162401, -2.10400981 - ], [0.83680821, 1.72827342], [0.3084254, 0.33299982], [ - 0.70472253, -0.73309052 - ], [0.28893132, -0.38761769], [1.15514042, 0.0129463], [ - 0.88407872, 0.35454207 - ], [1.31301027, -0.92648734], [-1.11515198, -0.93689695], [ - -0.18410027, -0.45194484 - ], [0.9281014, 0.53085498], [-0.14374509, 0.27370049], - [-0.41635887, -0.38299653], [0.08711622, 0.93259929], - [1.70580611, -0.11219234], [0.47436888, -0.2645749], - [1.07844561, -0.19435291], [1.44015515, -1.30621303]]) - y_gt = np.array( - [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 
1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]) + X_gt = np.array([[0.11622591, -0.0317206], + [0.77481731, 0.60935141], + [1.25192108, -0.22367336], + [0.53366841, -0.30312976], + [1.52091956, -0.49283504], + [-0.28162401, -2.10400981], + [0.83680821, 1.72827342], + [0.3084254, 0.33299982], + [0.70472253, -0.73309052], + [0.28893132, -0.38761769], + [1.15514042, 0.0129463], + [0.88407872, 0.35454207], + [1.31301027, -0.92648734], + [-1.11515198, -0.93689695], + [-0.18410027, -0.45194484], + [0.9281014, 0.53085498], + [-0.14374509, 0.27370049], + [-0.41635887, -0.38299653], + [0.08711622, 0.93259929], + [1.70580611, -0.11219234], + [0.47436887, -0.2645749], + [1.07844562, -0.19435291], + [1.44228238, -1.31256615], + [1.25636713, -1.04463226]]) + y_gt = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, + 1, 0, 1, 0, 0, 0, 0, 0]) assert_allclose(X_resampled, X_gt, rtol=R_TOL) assert_array_equal(y_resampled, y_gt) +@pytest.mark.filterwarnings('ignore:"kind" is deprecated in 0.4 and will be') +@pytest.mark.filterwarnings('ignore:"m_neighbors" is deprecated in 0.4 and') def test_fit_sample_nn_obj(): kind = 'borderline1' nn_m = NearestNeighbors(n_neighbors=11) @@ -194,9 +209,8 @@ def test_fit_sample_nn_obj(): def test_sample_regular_with_nn(): - kind = 'regular' nn_k = NearestNeighbors(n_neighbors=6) - smote = SMOTE(random_state=RND_SEED, kind=kind, k_neighbors=nn_k) + smote = SMOTE(random_state=RND_SEED, k_neighbors=nn_k) X_resampled, y_resampled = smote.fit_sample(X, Y) X_gt = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141], [ 1.25192108, -0.22367336 @@ -220,54 +234,72 @@ def test_sample_regular_with_nn(): assert_array_equal(y_resampled, y_gt) +@pytest.mark.filterwarnings('ignore:"kind" is deprecated in 0.4 and will be') +@pytest.mark.filterwarnings('ignore:"m_neighbors" is deprecated in 0.4 and') def test_wrong_nn(): kind = 'borderline1' nn_m = 'rnd' nn_k = NearestNeighbors(n_neighbors=6) smote = SMOTE( random_state=RND_SEED, kind=kind, k_neighbors=nn_k, 
m_neighbors=nn_m) - with raises(ValueError, match="has to be one of"): + with pytest.raises(ValueError, match="has to be one of"): smote.fit_sample(X, Y) nn_k = 'rnd' nn_m = NearestNeighbors(n_neighbors=10) smote = SMOTE( random_state=RND_SEED, kind=kind, k_neighbors=nn_k, m_neighbors=nn_m) - with raises(ValueError, match="has to be one of"): + with pytest.raises(ValueError, match="has to be one of"): smote.fit_sample(X, Y) kind = 'regular' nn_k = 'rnd' smote = SMOTE(random_state=RND_SEED, kind=kind, k_neighbors=nn_k) - with raises(ValueError, match="has to be one of"): + with pytest.raises(ValueError, match="has to be one of"): smote.fit_sample(X, Y) -def test_sample_regular_with_nn_svm(): +@pytest.mark.filterwarnings('ignore:"kind" is deprecated in 0.4 and will be') +@pytest.mark.filterwarnings('ignore:"svm_estimator" is deprecated in 0.4 and') +@pytest.mark.filterwarnings('ignore:"out_step" is deprecated in 0.4 and') +@pytest.mark.filterwarnings('ignore:"m_neighbors" is deprecated in 0.4 and') +def test_sample_with_nn_svm(): kind = 'svm' nn_k = NearestNeighbors(n_neighbors=6) svm = SVC(random_state=RND_SEED) smote = SMOTE( random_state=RND_SEED, kind=kind, k_neighbors=nn_k, svm_estimator=svm) X_resampled, y_resampled = smote.fit_sample(X, Y) - X_gt = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141], [ - 1.25192108, -0.22367336 - ], [0.53366841, -0.30312976], [1.52091956, -0.49283504], [ - -0.28162401, -2.10400981 - ], [0.83680821, 1.72827342], [0.3084254, 0.33299982], [ - 0.70472253, -0.73309052 - ], [0.28893132, -0.38761769], [1.15514042, 0.0129463], [ - 0.88407872, 0.35454207 - ], [1.31301027, -0.92648734], [-1.11515198, -0.93689695], [ - -0.18410027, -0.45194484 - ], [0.9281014, 0.53085498], [-0.14374509, 0.27370049], - [-0.41635887, -0.38299653], [0.08711622, 0.93259929], - [1.70580611, -0.11219234], [0.47436888, -0.2645749], - [1.07844561, -0.19435291], [1.44015515, -1.30621303]]) - y_gt = np.array( - [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 
1, 1, 1, 0, 1, 0, 0, 0, 0]) + X_gt = np.array([[0.11622591, -0.0317206], + [0.77481731, 0.60935141], + [1.25192108, -0.22367336], + [0.53366841, -0.30312976], + [1.52091956, -0.49283504], + [-0.28162401, -2.10400981], + [0.83680821, 1.72827342], + [0.3084254, 0.33299982], + [0.70472253, -0.73309052], + [0.28893132, -0.38761769], + [1.15514042, 0.0129463], + [0.88407872, 0.35454207], + [1.31301027, -0.92648734], + [-1.11515198, -0.93689695], + [-0.18410027, -0.45194484], + [0.9281014, 0.53085498], + [-0.14374509, 0.27370049], + [-0.41635887, -0.38299653], + [0.08711622, 0.93259929], + [1.70580611, -0.11219234], + [0.47436887, -0.2645749], + [1.07844562, -0.19435291], + [1.44228238, -1.31256615], + [1.25636713, -1.04463226]]) + y_gt = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0]) assert_allclose(X_resampled, X_gt, rtol=R_TOL) assert_array_equal(y_resampled, y_gt) +@pytest.mark.filterwarnings('ignore:"kind" is deprecated in 0.4 and will be') +@pytest.mark.filterwarnings('ignore:"svm_estimator" is deprecated in 0.4 and') def test_sample_regular_wrong_svm(): kind = 'svm' nn_k = NearestNeighbors(n_neighbors=6) @@ -275,5 +307,39 @@ def test_sample_regular_wrong_svm(): smote = SMOTE( random_state=RND_SEED, kind=kind, k_neighbors=nn_k, svm_estimator=svm) - with raises(ValueError, match="has to be one of"): + with pytest.raises(ValueError, match="has to be one of"): smote.fit_sample(X, Y) + + +def test_borderline_smote_wrong_kind(): + bsmote = BorderlineSMOTE(kind='rand') + with pytest.raises(ValueError, match='The possible "kind" of algorithm'): + bsmote.fit_sample(X, Y) + + +@pytest.mark.parametrize('kind', ['borderline-1', 'borderline-2']) +def test_borderline_smote(kind): + bsmote = BorderlineSMOTE(kind=kind, random_state=42) + bsmote_nn = BorderlineSMOTE(kind=kind, random_state=42, + k_neighbors=NearestNeighbors(n_neighbors=6), + m_neighbors=NearestNeighbors(n_neighbors=11)) + + X_res_1, y_res_1 = bsmote.fit_sample(X, Y) + X_res_2, 
y_res_2 = bsmote_nn.fit_sample(X, Y) + + assert_allclose(X_res_1, X_res_2) + assert_array_equal(y_res_1, y_res_2) + + +def test_svm_smote(): + svm_smote = SVMSMOTE(random_state=42) + svm_smote_nn = SVMSMOTE(random_state=42, + k_neighbors=NearestNeighbors(n_neighbors=6), + m_neighbors=NearestNeighbors(n_neighbors=11), + svm_estimator=SVC(random_state=42)) + + X_res_1, y_res_1 = svm_smote.fit_sample(X, Y) + X_res_2, y_res_2 = svm_smote_nn.fit_sample(X, Y) + + assert_allclose(X_res_1, X_res_2) + assert_array_equal(y_res_1, y_res_2) diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index 3480f55ac..3bb52d46d 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -34,6 +34,8 @@ from imblearn.utils.testing import warns +DONT_SUPPORT_RATIO = ['SVMSMOTE', 'BorderlineSMOTE'] + def _yield_sampler_checks(name, Estimator): yield check_target_type @@ -169,36 +171,37 @@ def check_samplers_fit_sample(name, Sampler): # FIXME remove in 0.6 -> ratio will be deprecated def check_samplers_ratio_fit_sample(name, Sampler): - # in this test we will force all samplers to not change the class 1 - X, y = make_classification( - n_samples=1000, - n_classes=3, - n_informative=4, - weights=[0.2, 0.3, 0.5], - random_state=0) - sampler = Sampler() - expected_stat = Counter(y)[1] - if isinstance(sampler, BaseOverSampler): - ratio = {2: 498, 0: 498} - sampler.set_params(ratio=ratio) - X_res, y_res = sampler.fit_sample(X, y) - assert Counter(y_res)[1] == expected_stat - elif isinstance(sampler, BaseUnderSampler): - ratio = {2: 201, 0: 201} - sampler.set_params(ratio=ratio) - X_res, y_res = sampler.fit_sample(X, y) - assert Counter(y_res)[1] == expected_stat - elif isinstance(sampler, BaseCleaningSampler): - ratio = {2: 201, 0: 201} - sampler.set_params(ratio=ratio) - X_res, y_res = sampler.fit_sample(X, y) - assert Counter(y_res)[1] == expected_stat - if isinstance(sampler, BaseEnsembleSampler): - ratio = {2: 201, 0: 201} - 
sampler.set_params(ratio=ratio) - X_res, y_res = sampler.fit_sample(X, y) - y_ensemble = y_res[0] - assert Counter(y_ensemble)[1] == expected_stat + if name not in DONT_SUPPORT_RATIO: + # in this test we will force all samplers to not change the class 1 + X, y = make_classification( + n_samples=1000, + n_classes=3, + n_informative=4, + weights=[0.2, 0.3, 0.5], + random_state=0) + sampler = Sampler() + expected_stat = Counter(y)[1] + if isinstance(sampler, BaseOverSampler): + ratio = {2: 498, 0: 498} + sampler.set_params(ratio=ratio) + X_res, y_res = sampler.fit_sample(X, y) + assert Counter(y_res)[1] == expected_stat + elif isinstance(sampler, BaseUnderSampler): + ratio = {2: 201, 0: 201} + sampler.set_params(ratio=ratio) + X_res, y_res = sampler.fit_sample(X, y) + assert Counter(y_res)[1] == expected_stat + elif isinstance(sampler, BaseCleaningSampler): + ratio = {2: 201, 0: 201} + sampler.set_params(ratio=ratio) + X_res, y_res = sampler.fit_sample(X, y) + assert Counter(y_res)[1] == expected_stat + if isinstance(sampler, BaseEnsembleSampler): + ratio = {2: 201, 0: 201} + sampler.set_params(ratio=ratio) + X_res, y_res = sampler.fit_sample(X, y) + y_ensemble = y_res[0] + assert Counter(y_ensemble)[1] == expected_stat def check_samplers_sampling_strategy_fit_sample(name, Sampler):