Commit d2328df

EHN: Create the EasyEnsembleClassifier (#455)
1 parent bcea35b

File tree: 7 files changed (+399, -7 lines)

doc/api.rst

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ Prototype selection
    ensemble.BalanceCascade
    ensemble.BalancedBaggingClassifier
    ensemble.EasyEnsemble
+   ensemble.EasyEnsembleClassifier

 .. _keras_ref:

doc/ensemble.rst

Lines changed: 23 additions & 2 deletions
@@ -11,6 +11,11 @@ Ensemble of samplers
 Samplers
 --------

+.. warning::
+   Note that :class:`EasyEnsemble` is deprecated and you should use
+   :class:`EasyEnsembleClassifier` instead. :class:`EasyEnsembleClassifier`
+   is presented in the next section.
+
 An imbalanced data set can be balanced by creating several balanced
 subsets. The module :mod:`imblearn.ensemble` allows creating such sets.

@@ -92,8 +97,8 @@ output of an :class:`EasyEnsemble` sampler with an ensemble of classifiers
 (i.e. ``BaggingClassifier``). Therefore, :class:`BalancedBaggingClassifier`
 takes the same parameters as the scikit-learn
 ``BaggingClassifier``. In addition, there are two extra parameters,
-``sampling_strategy`` and ``replacement``, as in the :class:`EasyEnsemble`
-sampler::
+``sampling_strategy`` and ``replacement``, to control the behaviour of the
+random under-sampler::

 >>> from imblearn.ensemble import BalancedBaggingClassifier

@@ -127,3 +132,19 @@ each tree::

 See
 :ref:`sphx_glr_auto_examples_ensemble_plot_comparison_bagging_classifier.py`.
+
+A specific method which uses ``AdaBoost`` learners in the bagging classifier
+is called "EasyEnsemble". The :class:`EasyEnsembleClassifier` allows bagging
+AdaBoost learners which are trained on balanced bootstrap samples. Similarly
+to the :class:`BalancedBaggingClassifier` API, one can construct the
+ensemble as::
+
+    >>> from imblearn.ensemble import EasyEnsembleClassifier
+    >>> eec = EasyEnsembleClassifier(random_state=0)
+    >>> eec.fit(X_train, y_train) # doctest: +ELLIPSIS
+    EasyEnsembleClassifier(...)
+    >>> y_pred = eec.predict(X_test)
+    >>> confusion_matrix(y_test, y_pred)
+    array([[  9,   1,   2],
+           [  5,  52,   2],
+           [252,  45, 882]])
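
Note that the number of balanced bags and the strength of each inner booster
are controlled independently. A minimal sketch (assuming only the API shown
in this commit; the parameter values are illustrative, not recommendations)
of tuning the inner AdaBoost by passing your own instance, as the new
docstring suggests:

    from sklearn.ensemble import AdaBoostClassifier
    from imblearn.ensemble import EasyEnsembleClassifier

    eec = EasyEnsembleClassifier(
        n_estimators=10,           # number of balanced bags
        base_estimator=AdaBoostClassifier(n_estimators=20),  # boosting rounds per bag
        sampling_strategy='auto',  # balance all classes
        replacement=False,         # under-sample without replacement
        random_state=0)
    # Then fit/predict exactly as in the doctest above.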

doc/whats_new/v0.0.4.rst

Lines changed: 9 additions & 0 deletions
@@ -24,6 +24,10 @@ New features
 - Add ``keras`` and ``tensorflow`` modules to create balanced mini-batch
   generators. :issue:`409` by :user:`Guillaume Lemaitre <glemaitre>`.

+- Add :class:`imblearn.ensemble.EasyEnsembleClassifier` which creates a bag
+  of AdaBoost classifiers trained on balanced bootstrap samples.
+  :issue:`455` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 Enhancement
 ...........

@@ -109,3 +113,8 @@ Deprecation
    :class:`imblearn.over_sampling.SVMSMOTE` and
    :class:`imblearn.over_sampling.BorderlineSMOTE`.
    :issue:`440` by :user:`Guillaume Lemaitre <glemaitre>`.
+
+- Deprecate :class:`imblearn.ensemble.EasyEnsemble` in favor of the
+  meta-estimator :class:`imblearn.ensemble.EasyEnsembleClassifier`, which
+  follows the exact algorithm described in the literature.
+  :issue:`455` by :user:`Guillaume Lemaitre <glemaitre>`.
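
Because the deprecation is implemented with scikit-learn's ``deprecated``
decorator (see the ``_easy_ensemble.py`` diff below), instantiating the old
class now emits a warning. A short sketch, assuming default construction is
enough to trigger it:

    import warnings
    from imblearn.ensemble import EasyEnsemble

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        EasyEnsemble()  # the decorator warns at instantiation time
    print(caught[0].category.__name__)  # DeprecationWarning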

imblearn/ensemble/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -4,8 +4,10 @@
 """

 from ._easy_ensemble import EasyEnsemble
+from ._easy_ensemble import EasyEnsembleClassifier
 from ._balance_cascade import BalanceCascade

 from ._classifier import BalancedBaggingClassifier

-__all__ = ['EasyEnsemble', 'BalancedBaggingClassifier', 'BalanceCascade']
+__all__ = ['EasyEnsemble', 'EasyEnsembleClassifier',
+           'BalancedBaggingClassifier', 'BalanceCascade']
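
A quick sanity check (a sketch, not part of the commit) that the new export
works; ``n_estimators`` defaults to 10 per the class definition below:

    from imblearn.ensemble import EasyEnsembleClassifier

    print(EasyEnsembleClassifier().get_params()['n_estimators'])  # 10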

imblearn/ensemble/_classifier.py

Lines changed: 3 additions & 3 deletions
@@ -194,7 +194,7 @@ def __init__(self,
                  verbose=0,
                  ratio=None):

-        super(BaggingClassifier, self).__init__(
+        super(BalancedBaggingClassifier, self).__init__(
             base_estimator,
             n_estimators=n_estimators,
             max_samples=max_samples,
@@ -237,10 +237,10 @@ def fit(self, X, y):

         Parameters
         ----------
-        X : array-like of shape = [n_samples, n_features]
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
             The training input samples.

-        y : array-like, shape = [n_samples]
+        y : array-like, shape (n_samples,)
             The target values.

         Returns
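
The first hunk fixes a subtle bug: in ``super(Cls, self)``, the MRO lookup
starts *after* ``Cls``, so naming the parent class (``BaggingClassifier``)
instead of the class being defined skips the parent's own method. A
self-contained sketch with hypothetical classes (not from this repository):

    class Base:
        def __init__(self):
            print("Base.__init__")

    class Parent(Base):
        def __init__(self):
            print("Parent.__init__")
            super(Parent, self).__init__()

    class Child(Parent):
        def __init__(self):
            # Correct: the lookup starts after Child, so Parent.__init__ runs.
            super(Child, self).__init__()
            # Buggy variant: super(Parent, self).__init__() starts the lookup
            # after Parent and jumps straight to Base.__init__, silently
            # skipping Parent.__init__ -- the shape of the bug fixed here.

    Child()  # prints Parent.__init__ then Base.__init__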

imblearn/ensemble/_easy_ensemble.py

Lines changed: 171 additions & 0 deletions
@@ -4,28 +4,41 @@
 # Christos Aridas
 # License: MIT

+import numbers
+
 import numpy as np

+from sklearn.base import clone
 from sklearn.utils import check_random_state
+from sklearn.ensemble import AdaBoostClassifier
+from sklearn.ensemble.bagging import BaggingClassifier
+from sklearn.utils.deprecation import deprecated

 from .base import BaseEnsembleSampler
 from ..under_sampling import RandomUnderSampler
 from ..under_sampling.base import BaseUnderSampler
 from ..utils import Substitution
 from ..utils._docstring import _random_state_docstring
+from ..pipeline import Pipeline

 MAX_INT = np.iinfo(np.int32).max


 @Substitution(
     sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
     random_state=_random_state_docstring)
+@deprecated('EasyEnsemble is deprecated in 0.4 and will be removed in 0.6. '
+            'Use EasyEnsembleClassifier instead.')
 class EasyEnsemble(BaseEnsembleSampler):
     """Create ensemble sets by iteratively applying random under-sampling.

     This method iteratively selects a random subset and makes an ensemble of
     the different sets.

+    .. deprecated:: 0.4
+       ``EasyEnsemble`` is deprecated in 0.4 and will be removed in 0.6. Use
+       ``EasyEnsembleClassifier`` instead.
+
     Read more in the :ref:`User Guide <ensemble_samplers>`.

     Parameters
@@ -126,3 +139,161 @@ def _sample(self, X, y):
                 np.array(idx_under))
         else:
             return np.array(X_resampled), np.array(y_resampled)
+
+
+@Substitution(
+    sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
+    random_state=_random_state_docstring)
+class EasyEnsembleClassifier(BaggingClassifier):
+    """Bag of balanced boosted learners, also known as EasyEnsemble.
+
+    This algorithm is known as EasyEnsemble [1]_. The classifier is an
+    ensemble of AdaBoost learners trained on different balanced bootstrap
+    samples. The balancing is achieved by random under-sampling.
+
+    Read more in the :ref:`User Guide <ensemble_samplers>`.
+
+    Parameters
+    ----------
+    n_estimators : int, optional (default=10)
+        Number of AdaBoost learners in the ensemble.
+
+    base_estimator : object, optional (default=AdaBoostClassifier())
+        The base AdaBoost classifier used in the inner ensemble. Note that
+        you can set the number of inner learners by passing your own
+        instance.
+
+    warm_start : bool, optional (default=False)
+        When set to True, reuse the solution of the previous call to fit
+        and add more estimators to the ensemble, otherwise, just fit
+        a whole new ensemble.
+
+    {sampling_strategy}
+
+    replacement : bool, optional (default=False)
+        Whether or not to sample with replacement.
+
+    n_jobs : int, optional (default=1)
+        The number of jobs to run in parallel for both `fit` and `predict`.
+        If -1, then the number of jobs is set to the number of cores.
+
+    {random_state}
+
+    verbose : int, optional (default=0)
+        Controls the verbosity of the building process.
+
+    Attributes
+    ----------
+    base_estimator_ : estimator
+        The base estimator from which the ensemble is grown.
+
+    estimators_ : list of estimators
+        The collection of fitted base estimators.
+
+    classes_ : array, shape (n_classes,)
+        The class labels.
+
+    n_classes_ : int or list
+        The number of classes.
+
+    Notes
+    -----
+    The method is described in [1]_.
+
+    Supports multi-class resampling by sampling each class independently.
+
+    See also
+    --------
+    BalanceCascade, BalancedBaggingClassifier
+
+    References
+    ----------
+    .. [1] X. Y. Liu, J. Wu and Z. H. Zhou, "Exploratory Undersampling for
+       Class-Imbalance Learning," in IEEE Transactions on Systems, Man, and
+       Cybernetics, Part B (Cybernetics), vol. 39, no. 2, pp. 539-550,
+       April 2009.
+
+    Examples
+    --------
+
+    >>> from collections import Counter
+    >>> from sklearn.datasets import make_classification
+    >>> from sklearn.model_selection import train_test_split
+    >>> from sklearn.metrics import confusion_matrix
+    >>> from imblearn.ensemble import \
+EasyEnsembleClassifier # doctest: +NORMALIZE_WHITESPACE
+    >>> X, y = make_classification(n_classes=2, class_sep=2,
+    ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
+    ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
+    >>> print('Original dataset shape %s' % Counter(y))
+    Original dataset shape Counter({{1: 900, 0: 100}})
+    >>> X_train, X_test, y_train, y_test = train_test_split(X, y,
+    ...                                                     random_state=0)
+    >>> eec = EasyEnsembleClassifier(random_state=42)
+    >>> eec.fit(X_train, y_train) # doctest: +ELLIPSIS
+    EasyEnsembleClassifier(...)
+    >>> y_pred = eec.predict(X_test)
+    >>> print(confusion_matrix(y_test, y_pred))
+    [[ 23   0]
+     [  2 225]]
+
+    """
+    def __init__(self, n_estimators=10, base_estimator=None, warm_start=False,
+                 sampling_strategy='auto', replacement=False, n_jobs=1,
+                 random_state=None, verbose=0):
+        super(EasyEnsembleClassifier, self).__init__(
+            base_estimator,
+            n_estimators=n_estimators,
+            max_samples=1.0,
+            max_features=1.0,
+            bootstrap=False,
+            bootstrap_features=False,
+            oob_score=False,
+            warm_start=warm_start,
+            n_jobs=n_jobs,
+            random_state=random_state,
+            verbose=verbose)
+        self.sampling_strategy = sampling_strategy
+        self.replacement = replacement
+
+    def _validate_estimator(self, default=AdaBoostClassifier()):
+        """Check the estimator and the n_estimators attribute, and set the
+        `base_estimator_` attribute."""
+        if not isinstance(self.n_estimators, (numbers.Integral, np.integer)):
+            raise ValueError("n_estimators must be an integer, "
+                             "got {0}.".format(type(self.n_estimators)))
+
+        if self.n_estimators <= 0:
+            raise ValueError("n_estimators must be greater than zero, "
+                             "got {0}.".format(self.n_estimators))
+
+        if self.base_estimator is not None:
+            base_estimator = clone(self.base_estimator)
+        else:
+            base_estimator = clone(default)
+
+        # Each member of the bag chains a random under-sampler with the
+        # (cloned) AdaBoost learner, so every member is trained on its own
+        # balanced subset.
+        self.base_estimator_ = Pipeline(
+            [('sampler', RandomUnderSampler(
+                sampling_strategy=self.sampling_strategy,
+                replacement=self.replacement)),
+             ('classifier', base_estimator)])
+
+    def fit(self, X, y):
+        """Build a bagging ensemble of AdaBoost classifiers using balanced
+        bootstrap samples obtained by random under-sampling.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
+            The training input samples.
+
+        y : array-like, shape (n_samples,)
+            The target values.
+
+        Returns
+        -------
+        self : object
+            Returns self.
+        """
+        # RandomUnderSampler does not support sample_weight, so we need to
+        # pass None.
+        return self._fit(X, y, self.max_samples, sample_weight=None)
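
Conceptually, the new meta-estimator is a thin wrapper around scikit-learn's
bagging. A rough hand-built equivalent of what ``_validate_estimator``
assembles (a sketch: it omits the ``n_estimators`` validation and uses only
public imports):

    from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
    from imblearn.pipeline import Pipeline
    from imblearn.under_sampling import RandomUnderSampler

    balanced_adaboost = Pipeline([
        ('sampler', RandomUnderSampler(sampling_strategy='auto',
                                       replacement=False)),
        ('classifier', AdaBoostClassifier()),
    ])
    # With the settings hard-coded in __init__ above (bootstrap=False,
    # max_samples=1.0), every member sees the full training set and the
    # random under-sampler provides each member's balanced random subset.
    ensemble = BaggingClassifier(base_estimator=balanced_adaboost,
                                 n_estimators=10, bootstrap=False,
                                 random_state=0)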
