diff --git a/doc/bibtex/refs.bib b/doc/bibtex/refs.bib
index 469d4abe8..fc9d9a475 100644
--- a/doc/bibtex/refs.bib
+++ b/doc/bibtex/refs.bib
@@ -244,3 +244,32 @@ @article{wilson1997improved
   pages={1--34},
   year={1997}
 }
+
+@inproceedings{wang2009diversity,
+  title={Diversity analysis on imbalanced data sets by using ensemble models},
+  author={Wang, Shuo and Yao, Xin},
+  booktitle={2009 IEEE symposium on computational intelligence and data mining},
+  pages={324--331},
+  year={2009},
+  organization={IEEE}
+}
+
+@article{hido2009roughly,
+  title={Roughly balanced bagging for imbalanced data},
+  author={Hido, Shohei and Kashima, Hisashi and Takahashi, Yutaka},
+  journal={Statistical Analysis and Data Mining: The ASA Data Science Journal},
+  volume={2},
+  number={5-6},
+  pages={412--426},
+  year={2009},
+  publisher={Wiley Online Library}
+}
+
+@article{maclin1997empirical,
+  title={An empirical evaluation of bagging and boosting},
+  author={Maclin, Richard and Opitz, David},
+  journal={AAAI/IAAI},
+  volume={1997},
+  pages={546--551},
+  year={1997}
+}
diff --git a/doc/ensemble.rst b/doc/ensemble.rst
index dc4ca94c7..886a2b02e 100644
--- a/doc/ensemble.rst
+++ b/doc/ensemble.rst
@@ -18,9 +18,9 @@ Bagging classifier
 
 In ensemble classifiers, bagging methods build several estimators on different
 randomly selected subset of data. In scikit-learn, this classifier is named
-``BaggingClassifier``. However, this classifier does not allow to balance each
-subset of data. Therefore, when training on imbalanced data set, this
-classifier will favor the majority classes::
+:class:`~sklearn.ensemble.BaggingClassifier`. However, this classifier does not
+allow balancing each subset of data. Therefore, when training on an imbalanced
+data set, this classifier will favor the majority classes::
 
   >>> from sklearn.datasets import make_classification
   >>> X, y = make_classification(n_samples=10000, n_features=2, n_informative=2,
@@ -41,14 +41,13 @@ classifier will favor the majority classes::
   >>> balanced_accuracy_score(y_test, y_pred) # doctest: +ELLIPSIS
   0.77...
 
-:class:`BalancedBaggingClassifier` allows to resample each subset of data
-before to train each estimator of the ensemble. In short, it combines the
-output of an :class:`EasyEnsemble` sampler with an ensemble of classifiers
-(i.e. ``BaggingClassifier``). Therefore, :class:`BalancedBaggingClassifier`
-takes the same parameters than the scikit-learn
-``BaggingClassifier``. Additionally, there is two additional parameters,
-``sampling_strategy`` and ``replacement`` to control the behaviour of the
-random under-sampler::
+In :class:`BalancedBaggingClassifier`, each bootstrap sample will be further
+resampled to achieve the desired `sampling_strategy`. Therefore,
+:class:`BalancedBaggingClassifier` takes the same parameters as the
+scikit-learn :class:`~sklearn.ensemble.BaggingClassifier`. In addition, the
+sampling is controlled by the parameter `sampler`, or by the two parameters
+`sampling_strategy` and `replacement` if one wants to use the
+:class:`~imblearn.under_sampling.RandomUnderSampler`::
 
   >>> from imblearn.ensemble import BalancedBaggingClassifier
   >>> bbc = BalancedBaggingClassifier(base_estimator=DecisionTreeClassifier(),
@@ -61,6 +60,12 @@ random under-sampler::
   >>> balanced_accuracy_score(y_test, y_pred) # doctest: +ELLIPSIS
   0.8...
 
+Changing the `sampler` parameter gives rise to different known implementations
+:cite:`maclin1997empirical`, :cite:`hido2009roughly`,
+:cite:`wang2009diversity`.
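+
+For instance, passing a :class:`~imblearn.over_sampling.SMOTE` instance as the
+`sampler` corresponds to the SMOTE-Bagging approach. The snippet below is only
+a minimal sketch reusing the training set created above::
+
+  >>> from imblearn.over_sampling import SMOTE
+  >>> bbc = BalancedBaggingClassifier(base_estimator=DecisionTreeClassifier(),
+  ...                                 sampler=SMOTE(), random_state=0)
+  >>> bbc.fit(X_train, y_train) # doctest: +ELLIPSIS
+  BalancedBaggingClassifier(...)
+  >>> y_pred = bbc.predict(X_test)
+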
+You can refer to the following example, which shows these different methods
+in practice:
+:ref:`sphx_glr_auto_examples_ensemble_plot_bagging_classifier.py`
+
 .. _forest:
 
 Forest of randomized trees
@@ -69,8 +74,7 @@ Forest of randomized trees
 :class:`BalancedRandomForestClassifier` is another ensemble method in which
 each tree of the forest will be provided a balanced bootstrap sample
 :cite:`chen2004using`. This class provides all functionality of the
-:class:`~sklearn.ensemble.RandomForestClassifier` and notably the
-`feature_importances_` attributes::
+:class:`~sklearn.ensemble.RandomForestClassifier`::
 
   >>> from imblearn.ensemble import BalancedRandomForestClassifier
   >>> brf = BalancedRandomForestClassifier(n_estimators=100, random_state=0)
@@ -99,11 +103,11 @@ a boosting iteration :cite:`seiffert2009rusboost`::
   >>> balanced_accuracy_score(y_test, y_pred) # doctest: +ELLIPSIS
   0...
 
-A specific method which uses ``AdaBoost`` as learners in the bagging classifier
-is called EasyEnsemble. The :class:`EasyEnsembleClassifier` allows to bag
-AdaBoost learners which are trained on balanced bootstrap samples
-:cite:`liu2008exploratory`. Similarly to the :class:`BalancedBaggingClassifier`
-API, one can construct the ensemble as::
+A specific method which uses :class:`~sklearn.ensemble.AdaBoostClassifier` as
+learners in the bagging classifier is called "EasyEnsemble". The
+:class:`EasyEnsembleClassifier` makes it possible to bag AdaBoost learners that
+are trained on balanced bootstrap samples :cite:`liu2008exploratory`. Similarly
+to the :class:`BalancedBaggingClassifier` API, one can construct the ensemble as::
 
   >>> from imblearn.ensemble import EasyEnsembleClassifier
   >>> eec = EasyEnsembleClassifier(random_state=0)
diff --git a/doc/whats_new/v0.8.rst b/doc/whats_new/v0.8.rst
index 343384942..2863c4fc4 100644
--- a/doc/whats_new/v0.8.rst
+++ b/doc/whats_new/v0.8.rst
@@ -24,6 +24,11 @@ New features
   only containing categorical features.
   :pr:`802` by :user:`Guillaume Lemaitre `.
 
+- Add the possibility to pass any type of sampler to
+  :class:`imblearn.ensemble.BalancedBaggingClassifier`, unlocking the
+  implementation of methods based on resampled bagging.
+  :pr:`808` by :user:`Guillaume Lemaitre `.
+
 Enhancements
 ............
 
diff --git a/examples/ensemble/plot_bagging_classifier.py b/examples/ensemble/plot_bagging_classifier.py
new file mode 100644
index 000000000..9e392c506
--- /dev/null
+++ b/examples/ensemble/plot_bagging_classifier.py
@@ -0,0 +1,175 @@
+"""
+====================================
+Bagging classifiers using a sampler
+====================================
+
+In this example, we show how
+:class:`~imblearn.ensemble.BalancedBaggingClassifier` can be used to create a
+large variety of classifiers by passing different samplers.
+
+We will give several examples of methods that have been published over the
+past years.
+"""
+
+# Authors: Guillaume Lemaitre
+# License: MIT
+
+# %%
+print(__doc__)
+
+# %% [markdown]
+# Generate an imbalanced dataset
+# ------------------------------
+#
+# For this example, we will create a synthetic dataset using the function
+# :func:`~sklearn.datasets.make_classification`. The problem will be a toy
+# classification problem with a ratio of 1:9 between the two classes.
+
+# %%
+from sklearn.datasets import make_classification
+
+X, y = make_classification(
+    n_samples=10_000,
+    n_features=10,
+    weights=[0.1, 0.9],
+    class_sep=0.5,
+    random_state=0,
+)
+
+# %%
+import pandas as pd
+
+pd.Series(y).value_counts(normalize=True)
+
+# %% [markdown]
+# In the following sections, we will show several algorithms that have been
+# proposed over the years. We intend to illustrate how one can reuse the
+# :class:`~imblearn.ensemble.BalancedBaggingClassifier` by passing different
+# samplers.
+#
+# As a baseline, we first cross-validate a vanilla
+# :class:`~sklearn.ensemble.BaggingClassifier` that does not perform any
+# resampling.
+
+# %%
+from sklearn.model_selection import cross_validate
+from sklearn.ensemble import BaggingClassifier
+
+bagging = BaggingClassifier()
+cv_results = cross_validate(bagging, X, y, scoring="balanced_accuracy")
+
+print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")
+
+# %% [markdown]
+# Exactly Balanced Bagging and Over-Bagging
+# -----------------------------------------
+#
+# The :class:`~imblearn.ensemble.BalancedBaggingClassifier` can be used in
+# conjunction with a :class:`~imblearn.under_sampling.RandomUnderSampler` or a
+# :class:`~imblearn.over_sampling.RandomOverSampler`. These methods are
+# referred to as Exactly Balanced Bagging and Over-Bagging, respectively, and
+# were first proposed in [1]_.
+
+# %%
+from imblearn.ensemble import BalancedBaggingClassifier
+from imblearn.under_sampling import RandomUnderSampler
+
+# Exactly Balanced Bagging
+ebb = BalancedBaggingClassifier(sampler=RandomUnderSampler())
+cv_results = cross_validate(ebb, X, y, scoring="balanced_accuracy")
+
+print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")
+
+# %%
+from imblearn.over_sampling import RandomOverSampler
+
+# Over-bagging
+over_bagging = BalancedBaggingClassifier(sampler=RandomOverSampler())
+cv_results = cross_validate(over_bagging, X, y, scoring="balanced_accuracy")
+
+print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")
+
+# %% [markdown]
+# SMOTE-Bagging
+# -------------
+#
+# Instead of using a :class:`~imblearn.over_sampling.RandomOverSampler`, which
+# bootstraps the minority class, an alternative is to use
+# :class:`~imblearn.over_sampling.SMOTE` as an over-sampler. This is known as
+# SMOTE-Bagging [2]_.
+
+# %%
+from imblearn.over_sampling import SMOTE
+
+# SMOTE-Bagging
+smote_bagging = BalancedBaggingClassifier(sampler=SMOTE())
+cv_results = cross_validate(smote_bagging, X, y, scoring="balanced_accuracy")
+
+print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")
+
+# %% [markdown]
+# Roughly Balanced Bagging
+# ------------------------
+#
+# While using a :class:`~imblearn.under_sampling.RandomUnderSampler` or a
+# :class:`~imblearn.over_sampling.RandomOverSampler` will create exactly the
+# desired number of samples, it does not follow the statistical spirit of the
+# bagging framework. The authors of [3]_ propose to use a negative binomial
+# distribution to compute the number of samples of the majority class to be
+# selected, and then perform random under-sampling.
+#
+# Here, we illustrate this method by implementing a function in charge of the
+# resampling and using :class:`~imblearn.FunctionSampler` to integrate it
+# within a :class:`~imblearn.pipeline.Pipeline` and
+# :class:`~sklearn.model_selection.cross_validate`.
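+
+# %% [markdown]
+# As a quick sanity check before looking at the implementation below, note
+# that a negative binomial distribution with `n` set to the number of minority
+# samples and `p=0.5` has mean `n * (1 - p) / p = n`. On average, the number
+# of majority samples drawn therefore matches the number of minority samples,
+# hence the "roughly balanced" bootstrap samples.
+
+# %%
+import numpy as np
+
+# average number of majority samples drawn when the minority class contains
+# 100 samples; this should be close to 100
+np.random.negative_binomial(n=100, p=0.5, size=10_000).mean()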
+ +# %% +from collections import Counter +import numpy as np +from imblearn import FunctionSampler + + +def roughly_balanced_bagging(X, y, replace=False): + """Implementation of Roughly Balanced Bagging for binary problem.""" + # find the minority and majority classes + class_counts = Counter(y) + majority_class = max(class_counts, key=class_counts.get) + minority_class = min(class_counts, key=class_counts.get) + + # compute the number of sample to draw from the majority class using + # a negative binomial distribution + n_minority_class = class_counts[minority_class] + n_majority_resampled = np.random.negative_binomial(n=n_minority_class, p=0.5) + + # draw randomly with or without replacement + majority_indices = np.random.choice( + np.flatnonzero(y == majority_class), + size=n_majority_resampled, + replace=replace, + ) + minority_indices = np.random.choice( + np.flatnonzero(y == minority_class), + size=n_minority_class, + replace=replace, + ) + indices = np.hstack([majority_indices, minority_indices]) + + return X[indices], y[indices] + + +# Roughly Balanced Bagging +rbb = BalancedBaggingClassifier( + sampler=FunctionSampler(func=roughly_balanced_bagging, kw_args={"replace": True}) +) +cv_results = cross_validate(rbb, X, y, scoring="balanced_accuracy") + +print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}") + + +# %% [markdown] +# .. topic:: References: +# +# .. [1] R. Maclin, and D. Opitz. "An empirical evaluation of bagging and +# boosting." AAAI/IAAI 1997 (1997): 546-551. +# +# .. [2] S. Wang, and X. Yao. "Diversity analysis on imbalanced data sets by +# using ensemble models." 2009 IEEE symposium on computational +# intelligence and data mining. IEEE, 2009. +# +# .. [3] S. Hido, H. Kashima, and Y. Takahashi. "Roughly balanced bagging +# for imbalanced data." Statistical Analysis and Data Mining: The ASA +# Data Science Journal 2.5‐6 (2009): 412-426. diff --git a/examples/ensemble/plot_comparison_ensemble_classifier.py b/examples/ensemble/plot_comparison_ensemble_classifier.py index 294283f4e..65bcf9965 100644 --- a/examples/ensemble/plot_comparison_ensemble_classifier.py +++ b/examples/ensemble/plot_comparison_ensemble_classifier.py @@ -3,7 +3,7 @@ Compare ensemble classifiers using resampling ============================================= -Ensembling classifiers have shown to improve classification performance compare +Ensemble classifiers have shown to improve classification performance compare to single learner. However, they will be affected by class imbalance. This example shows the benefit of balancing the training set before to learn learners. We are making the comparison with non-balanced ensemble methods. @@ -11,7 +11,6 @@ We make a comparison using the balanced accuracy and geometric mean which are metrics widely used in the literature to evaluate models learned on imbalanced set. - """ # Authors: Guillaume Lemaitre diff --git a/imblearn/ensemble/_bagging.py b/imblearn/ensemble/_bagging.py index 073147038..d773fd127 100644 --- a/imblearn/ensemble/_bagging.py +++ b/imblearn/ensemble/_bagging.py @@ -31,7 +31,11 @@ class BalancedBaggingClassifier(BaggingClassifier): This implementation of Bagging is similar to the scikit-learn implementation. It includes an additional step to balance the training set - at fit time using a ``RandomUnderSampler``. + at fit time using a given sampler. 
+
+    This classifier can serve as a basis to implement various methods such as
+    Exactly Balanced Bagging [6]_, Roughly Balanced Bagging [7]_,
+    Over-Bagging [6]_, or SMOTE-Bagging [8]_.
 
     Read more in the :ref:`User Guide <ensemble>`.
 
@@ -59,6 +63,10 @@ class BalancedBaggingClassifier(BaggingClassifier):
     bootstrap : bool, default=True
         Whether samples are drawn with replacement.
 
+        .. note::
+           This bootstrap will be generated from the resampled
+           dataset.
+
     bootstrap_features : bool, default=False
         Whether features are drawn with replacement.
 
@@ -74,7 +82,9 @@ class BalancedBaggingClassifier(BaggingClassifier):
     {sampling_strategy}
 
     replacement : bool, default=False
-        Whether or not to sample randomly with replacement or not.
+        Whether to randomly sample with replacement when `sampler is None`,
+        which corresponds to using a
+        :class:`~imblearn.under_sampling.RandomUnderSampler`.
 
     {n_jobs}
 
@@ -83,6 +93,13 @@ class BalancedBaggingClassifier(BaggingClassifier):
     verbose : int, default=0
         Controls the verbosity of the building process.
 
+    sampler : sampler object, default=None
+        The sampler used to balance the dataset before bootstrapping
+        (if `bootstrap=True`) and fitting a base estimator. By default, a
+        :class:`~imblearn.under_sampling.RandomUnderSampler` is used.
+
+        .. versionadded:: 0.8
+
     Attributes
     ----------
     base_estimator_ : estimator
@@ -151,10 +168,21 @@ class BalancedBaggingClassifier(BaggingClassifier):
     .. [4] G. Louppe and P. Geurts, "Ensembles on Random Patches", Machine
        Learning and Knowledge Discovery in Databases, 346-361, 2012.
 
-    .. [5] Chen, Chao, Andy Liaw, and Leo Breiman. "Using random forest to
+    .. [5] C. Chen, A. Liaw, and L. Breiman. "Using random forest to
        learn imbalanced data." University of California, Berkeley 110,
        2004.
 
+    .. [6] R. Maclin, and D. Opitz. "An empirical evaluation of bagging and
+       boosting." AAAI/IAAI 1997 (1997): 546-551.
+
+    .. [7] S. Hido, H. Kashima, and Y. Takahashi. "Roughly balanced bagging
+       for imbalanced data." Statistical Analysis and Data Mining: The ASA
+       Data Science Journal 2.5-6 (2009): 412-426.
+
+    .. [8] S. Wang, and X. Yao. "Diversity analysis on imbalanced data sets by
+       using ensemble models." 2009 IEEE symposium on computational
+       intelligence and data mining. IEEE, 2009.
+
     Examples
     --------
     >>> from collections import Counter
@@ -196,6 +224,7 @@ def __init__(
         n_jobs=None,
         random_state=None,
         verbose=0,
+        sampler=None,
     ):
 
         super().__init__(
@@ -213,16 +242,20 @@ def __init__(
         )
         self.sampling_strategy = sampling_strategy
         self.replacement = replacement
+        self.sampler = sampler
 
     def _validate_y(self, y):
         y_encoded = super()._validate_y(y)
-        if isinstance(self.sampling_strategy, dict):
+        if (
+            isinstance(self.sampling_strategy, dict)
+            and self.sampler_._sampling_type != "bypass"
+        ):
             self._sampling_strategy = {
                 np.where(self.classes_ == key)[0][0]: value
                 for key, value in check_sampling_strategy(
                     self.sampling_strategy,
                     y,
-                    "under-sampling",
+                    self.sampler_._sampling_type,
                 ).items()
             }
         else:
@@ -247,15 +280,12 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
         else:
             base_estimator = clone(default)
 
+        if self.sampler_._sampling_type != "bypass":
+            self.sampler_.set_params(sampling_strategy=self._sampling_strategy)
+
         self.base_estimator_ = Pipeline(
             [
-                (
-                    "sampler",
-                    RandomUnderSampler(
-                        sampling_strategy=self._sampling_strategy,
-                        replacement=self.replacement,
-                    ),
-                ),
+                ("sampler", self.sampler_),
                 ("classifier", base_estimator),
             ]
         )
@@ -277,6 +307,15 @@ def fit(self, X, y):
             Returns self.
""" check_target_type(y) + # the sampler needs to be validated before to call _fit because + # _validate_y is called before _validate_estimator and would require + # to know which type of sampler we are using. + if self.sampler is None: + self.sampler_ = RandomUnderSampler( + replacement=self.replacement, + ) + else: + self.sampler_ = clone(self.sampler) # RandomUnderSampler is not supporting sample_weight. We need to pass # None. return self._fit(X, y, self.max_samples, sample_weight=None) diff --git a/imblearn/ensemble/tests/test_bagging.py b/imblearn/ensemble/tests/test_bagging.py index b889c9dec..f3eff340a 100644 --- a/imblearn/ensemble/tests/test_bagging.py +++ b/imblearn/ensemble/tests/test_bagging.py @@ -3,10 +3,12 @@ # Christos Aridas # License: MIT +from collections import Counter + import numpy as np import pytest -from sklearn.datasets import load_iris, make_hastie_10_2 +from sklearn.datasets import load_iris, make_hastie_10_2, make_classification from sklearn.model_selection import ( GridSearchCV, ParameterGrid, @@ -22,47 +24,60 @@ from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_allclose +from imblearn import FunctionSampler from imblearn.datasets import make_imbalance from imblearn.ensemble import BalancedBaggingClassifier +from imblearn.over_sampling import RandomOverSampler, SMOTE from imblearn.pipeline import make_pipeline -from imblearn.under_sampling import RandomUnderSampler +from imblearn.under_sampling import ClusterCentroids, RandomUnderSampler iris = load_iris() -def test_balanced_bagging_classifier(): - # Check classification for various parameter settings. - X, y = make_imbalance( - iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50}, random_state=0, - ) - X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) - grid = ParameterGrid( +@pytest.mark.parametrize( + "base_estimator", + [ + None, + DummyClassifier(strategy="prior"), + Perceptron(max_iter=1000, tol=1e-3), + DecisionTreeClassifier(), + KNeighborsClassifier(), + SVC(gamma="scale"), + ], +) +@pytest.mark.parametrize( + "params", + ParameterGrid( { "max_samples": [0.5, 1.0], "max_features": [1, 2, 4], "bootstrap": [True, False], "bootstrap_features": [True, False], } + ), +) +def test_balanced_bagging_classifier(base_estimator, params): + # Check classification for various parameter settings. + X, y = make_imbalance( + iris.data, + iris.target, + sampling_strategy={0: 20, 1: 25, 2: 50}, + random_state=0, ) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) - for base_estimator in [ - None, - DummyClassifier(strategy="prior"), - Perceptron(max_iter=1000, tol=1e-3), - DecisionTreeClassifier(), - KNeighborsClassifier(), - SVC(gamma="scale"), - ]: - for params in grid: - BalancedBaggingClassifier( - base_estimator=base_estimator, random_state=0, **params - ).fit(X_train, y_train).predict(X_test) + BalancedBaggingClassifier( + base_estimator=base_estimator, random_state=0, **params + ).fit(X_train, y_train).predict(X_test) def test_bootstrap_samples(): # Test that bootstrapping samples generate non-perfect base estimators. 
X, y = make_imbalance( - iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50}, random_state=0, + iris.data, + iris.target, + sampling_strategy={0: 20, 1: 25, 2: 50}, + random_state=0, ) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) @@ -95,7 +110,10 @@ def test_bootstrap_samples(): def test_bootstrap_features(): # Test that bootstrapping features may generate duplicate features. X, y = make_imbalance( - iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50}, random_state=0, + iris.data, + iris.target, + sampling_strategy={0: 20, 1: 25, 2: 50}, + random_state=0, ) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) @@ -125,7 +143,10 @@ def test_bootstrap_features(): def test_probability(): # Predict probabilities. X, y = make_imbalance( - iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50}, random_state=0, + iris.data, + iris.target, + sampling_strategy={0: 20, 1: 25, 2: 50}, + random_state=0, ) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) @@ -136,11 +157,13 @@ def test_probability(): ).fit(X_train, y_train) assert_array_almost_equal( - np.sum(ensemble.predict_proba(X_test), axis=1), np.ones(len(X_test)), + np.sum(ensemble.predict_proba(X_test), axis=1), + np.ones(len(X_test)), ) assert_array_almost_equal( - ensemble.predict_proba(X_test), np.exp(ensemble.predict_log_proba(X_test)), + ensemble.predict_proba(X_test), + np.exp(ensemble.predict_log_proba(X_test)), ) # Degenerate case, where some classes are missing @@ -152,11 +175,13 @@ def test_probability(): ensemble.fit(X_train, y_train) assert_array_almost_equal( - np.sum(ensemble.predict_proba(X_test), axis=1), np.ones(len(X_test)), + np.sum(ensemble.predict_proba(X_test), axis=1), + np.ones(len(X_test)), ) assert_array_almost_equal( - ensemble.predict_proba(X_test), np.exp(ensemble.predict_log_proba(X_test)), + ensemble.predict_proba(X_test), + np.exp(ensemble.predict_log_proba(X_test)), ) @@ -164,7 +189,10 @@ def test_oob_score_classification(): # Check that oob prediction is a good estimation of the generalization # error. X, y = make_imbalance( - iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50}, random_state=0, + iris.data, + iris.target, + sampling_strategy={0: 20, 1: 25, 2: 50}, + random_state=0, ) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) @@ -195,7 +223,10 @@ def test_oob_score_classification(): def test_single_estimator(): # Check singleton ensembles. X, y = make_imbalance( - iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50}, random_state=0, + iris.data, + iris.target, + sampling_strategy={0: 20, 1: 25, 2: 50}, + random_state=0, ) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) @@ -215,42 +246,32 @@ def test_single_estimator(): assert_array_equal(clf1.predict(X_test), clf2.predict(X_test)) -def test_error(): +@pytest.mark.parametrize( + "params", + [ + {"n_estimators": 1.5}, + {"n_estimators": -1}, + {"max_samples": -1}, + {"max_samples": 0.0}, + {"max_samples": 2.0}, + {"max_samples": 1000}, + {"max_samples": "foobar"}, + {"max_features": -1}, + {"max_features": 0.0}, + {"max_features": 2.0}, + {"max_features": 5}, + {"max_features": "foobar"}, + ], +) +def test_balanced_bagging_classifier_error(params): # Test that it gives proper exception on deficient input. 
X, y = make_imbalance( iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50} ) base = DecisionTreeClassifier() - - # Test n_estimators - with pytest.raises(ValueError): - BalancedBaggingClassifier(base, n_estimators=1.5).fit(X, y) - with pytest.raises(ValueError): - BalancedBaggingClassifier(base, n_estimators=-1).fit(X, y) - - # Test max_samples - with pytest.raises(ValueError): - BalancedBaggingClassifier(base, max_samples=-1).fit(X, y) - with pytest.raises(ValueError): - BalancedBaggingClassifier(base, max_samples=0.0).fit(X, y) - with pytest.raises(ValueError): - BalancedBaggingClassifier(base, max_samples=2.0).fit(X, y) + clf = BalancedBaggingClassifier(base_estimator=base, **params) with pytest.raises(ValueError): - BalancedBaggingClassifier(base, max_samples=1000).fit(X, y) - with pytest.raises(ValueError): - BalancedBaggingClassifier(base, max_samples="foobar").fit(X, y) - - # Test max_features - with pytest.raises(ValueError): - BalancedBaggingClassifier(base, max_features=-1).fit(X, y) - with pytest.raises(ValueError): - BalancedBaggingClassifier(base, max_features=0.0).fit(X, y) - with pytest.raises(ValueError): - BalancedBaggingClassifier(base, max_features=2.0).fit(X, y) - with pytest.raises(ValueError): - BalancedBaggingClassifier(base, max_features=5).fit(X, y) - with pytest.raises(ValueError): - BalancedBaggingClassifier(base, max_features="foobar").fit(X, y) + clf.fit(X, y) # Test support of decision_function assert not (hasattr(BalancedBaggingClassifier(base).fit(X, y), "decision_function")) @@ -276,7 +297,10 @@ def test_gridsearch(): def test_base_estimator(): # Check base_estimator and its default values. X, y = make_imbalance( - iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50}, random_state=0, + iris.data, + iris.target, + sampling_strategy={0: 20, 1: 25, 2: 50}, + random_state=0, ) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) @@ -301,10 +325,14 @@ def test_base_estimator(): def test_bagging_with_pipeline(): X, y = make_imbalance( - iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50}, random_state=0, + iris.data, + iris.target, + sampling_strategy={0: 20, 1: 25, 2: 50}, + random_state=0, ) estimator = BalancedBaggingClassifier( - make_pipeline(SelectKBest(k=1), DecisionTreeClassifier()), max_features=2, + make_pipeline(SelectKBest(k=1), DecisionTreeClassifier()), + max_features=2, ) estimator.fit(X, y).predict(X) @@ -318,7 +346,9 @@ def test_warm_start(random_state=42): for n_estimators in [5, 10]: if clf_ws is None: clf_ws = BalancedBaggingClassifier( - n_estimators=n_estimators, random_state=random_state, warm_start=True, + n_estimators=n_estimators, + random_state=random_state, + warm_start=True, ) else: clf_ws.set_params(n_estimators=n_estimators) @@ -477,3 +507,100 @@ def test_max_samples_consistency(): ) bagging.fit(X, y) assert bagging._max_samples == max_samples + + +class CountDecisionTreeClassifier(DecisionTreeClassifier): + """DecisionTreeClassifier that will memorize the number of samples seen + at fit.""" + + def fit(self, X, y, sample_weight=None): + self.class_counts_ = Counter(y) + return super().fit(X, y, sample_weight=sample_weight) + + +@pytest.mark.parametrize( + "sampler, n_samples_bootstrap", + [ + (None, 15), + (RandomUnderSampler(), 15), # under-sampling with sample_indices_ + (ClusterCentroids(), 15), # under-sampling without sample_indices_ + (RandomOverSampler(), 40), # over-sampling with sample_indices_ + (SMOTE(), 40), # over-sampling without sample_indices_ + ], +) +def 
test_balanced_bagging_classifier_samplers(sampler, n_samples_bootstrap): + # check that we can pass any kind of sampler to a bagging classifier + X, y = make_imbalance( + iris.data, + iris.target, + sampling_strategy={0: 20, 1: 25, 2: 50}, + random_state=0, + ) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + clf = BalancedBaggingClassifier( + base_estimator=CountDecisionTreeClassifier(), + n_estimators=2, + sampler=sampler, + random_state=0, + ) + clf.fit(X_train, y_train) + clf.predict(X_test) + + # check that we have balanced class with the right counts of class + # sample depending on the sampling strategy + assert_array_equal( + list(clf.estimators_[0][-1].class_counts_.values()), n_samples_bootstrap + ) + + +@pytest.mark.parametrize("replace", [True, False]) +def test_balanced_bagging_classifier_with_function_sampler(replace): + # check that we can provide a FunctionSampler in BalancedBaggingClassifier + X, y = make_classification( + n_samples=1_000, + n_features=10, + n_classes=2, + weights=[0.3, 0.7], + random_state=0, + ) + + def roughly_balanced_bagging(X, y, replace=False): + """Implementation of Roughly Balanced Bagging for binary problem.""" + # find the minority and majority classes + class_counts = Counter(y) + majority_class = max(class_counts, key=class_counts.get) + minority_class = min(class_counts, key=class_counts.get) + + # compute the number of sample to draw from the majority class using + # a negative binomial distribution + n_minority_class = class_counts[minority_class] + n_majority_resampled = np.random.negative_binomial(n=n_minority_class, p=0.5) + + # draw randomly with or without replacement + majority_indices = np.random.choice( + np.flatnonzero(y == majority_class), + size=n_majority_resampled, + replace=replace, + ) + minority_indices = np.random.choice( + np.flatnonzero(y == minority_class), + size=n_minority_class, + replace=replace, + ) + indices = np.hstack([majority_indices, minority_indices]) + + return X[indices], y[indices] + + # Roughly Balanced Bagging + rbb = BalancedBaggingClassifier( + base_estimator=CountDecisionTreeClassifier(), + n_estimators=2, + sampler=FunctionSampler( + func=roughly_balanced_bagging, kw_args={"replace": replace} + ), + ) + rbb.fit(X, y) + + for estimator in rbb.estimators_: + class_counts = estimator[-1].class_counts_ + assert (class_counts[0] / class_counts[1]) > 0.8 diff --git a/references.bib b/references.bib index 398f9e4c3..c803f0ae9 100644 --- a/references.bib +++ b/references.bib @@ -219,3 +219,32 @@ @article{wilson1997improved pages={1--34}, year={1997} } + +@inproceedings{wang2009diversity, + title={Diversity analysis on imbalanced data sets by using ensemble models}, + author={Wang, Shuo and Yao, Xin}, + booktitle={2009 IEEE symposium on computational intelligence and data mining}, + pages={324--331}, + year={2009}, + organization={IEEE} +} + +@article{hido2009roughly, + title={Roughly balanced bagging for imbalanced data}, + author={Hido, Shohei and Kashima, Hisashi and Takahashi, Yutaka}, + journal={Statistical Analysis and Data Mining: The ASA Data Science Journal}, + volume={2}, + number={5-6}, + pages={412--426}, + year={2009}, + publisher={Wiley Online Library} +} + +@article{maclin1997empirical, + title={An empirical evaluation of bagging and boosting}, + author={Maclin, Richard and Opitz, David}, + journal={AAAI/IAAI}, + volume={1997}, + pages={546--551}, + year={1997} +}