Skip to content

Commit b457ab3

Browse files
committed
ENH add BalancedBaggingClassifier
1 parent 333d81b commit b457ab3

File tree

3 files changed

+309
-91
lines changed

3 files changed

+309
-91
lines changed

imblearn/ensemble/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
under-sampled subsets combined inside an ensemble.
44
"""
55

6-
from .easy_ensemble import EasyEnsemble
6+
from .easy_ensemble import EasyEnsemble, BalancedBaggingClassifier
77
from .balance_cascade import BalanceCascade
88

9-
__all__ = ['EasyEnsemble', 'BalanceCascade']
9+
__all__ = ['EasyEnsemble', 'BalancedBaggingClassifier', 'BalanceCascade']

imblearn/ensemble/easy_ensemble.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@
44
# Christos Aridas
55
# License: MIT
66

7+
import numbers
8+
79
import numpy as np
810

11+
from sklearn.ensemble import BaggingClassifier
12+
from sklearn.tree import DecisionTreeClassifier
913
from sklearn.utils import check_random_state
1014

1115
from .base import BaseEnsembleSampler
16+
from ..pipeline import Pipeline
1217
from ..under_sampling import RandomUnderSampler
1318

1419
MAX_INT = np.iinfo(np.int32).max
@@ -147,3 +152,178 @@ def _sample(self, X, y):
147152
np.array(idx_under))
148153
else:
149154
return np.array(X_resampled), np.array(y_resampled)
155+
156+
157+
class BalancedBaggingClassifier(BaggingClassifier):
    """A Bagging classifier with additional balancing.

    This implementation of Bagging is similar to the scikit-learn
    implementation. It includes an additional step to balance the training set
    at fit time using a ``RandomUnderSampler``.

    Read more in the :ref:`User Guide <bagging>`.

    Parameters
    ----------
    base_estimator : object or None, optional (default=None)
        The base estimator to fit on random subsets of the dataset.
        If None, then the base estimator is a decision tree.

    n_estimators : int, optional (default=10)
        The number of base estimators in the ensemble.

    max_samples : int or float, optional (default=1.0)
        The number of samples to draw from X to train each base estimator.

        - If int, then draw ``max_samples`` samples.
        - If float, then draw ``max_samples * X.shape[0]`` samples.

    max_features : int or float, optional (default=1.0)
        The number of features to draw from X to train each base estimator.

        - If int, then draw ``max_features`` features.
        - If float, then draw ``max_features * X.shape[1]`` features.

    bootstrap : boolean, optional (default=True)
        Whether samples are drawn with replacement.

    bootstrap_features : boolean, optional (default=False)
        Whether features are drawn with replacement.

    oob_score : bool, optional (default=False)
        Whether to use out-of-bag samples to estimate
        the generalization error.

    warm_start : bool, optional (default=False)
        When set to True, reuse the solution of the previous call to fit
        and add more estimators to the ensemble, otherwise, just fit
        a whole new ensemble.

        .. versionadded:: 0.17
           *warm_start* constructor parameter.

    ratio : str, dict, or callable, optional (default='auto')
        Ratio to use for resampling the data set.

        - If ``str``, has to be one of: (i) ``'minority'``: resample the
          minority class; (ii) ``'majority'``: resample the majority class,
          (iii) ``'not minority'``: resample all classes apart of the minority
          class, (iv) ``'all'``: resample all classes, and (v) ``'auto'``:
          correspond to ``'all'`` for over-sampling methods and ``'not
          minority'`` for under-sampling methods. The classes targeted will be
          over-sampled or under-sampled to achieve an equal number of samples
          with the majority or minority class.
        - If ``dict``, the keys correspond to the targeted classes. The values
          correspond to the desired number of samples.
        - If callable, function taking ``y`` and returns a ``dict``. The keys
          correspond to the targeted classes. The values correspond to the
          desired number of samples.

    replacement : bool, optional (default=False)
        Whether to sample with replacement when under-sampling.

    n_jobs : int, optional (default=1)
        The number of jobs to run in parallel for both `fit` and `predict`.
        If -1, then the number of jobs is set to the number of cores.

    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    verbose : int, optional (default=0)
        Controls the verbosity of the building process.

    Attributes
    ----------
    base_estimator_ : estimator
        The base estimator from which the ensemble is grown.

    estimators_ : list of estimators
        The collection of fitted base estimators.

    estimators_samples_ : list of arrays
        The subset of drawn samples (i.e., the in-bag samples) for each base
        estimator. Each subset is defined by a boolean mask.

    estimators_features_ : list of arrays
        The subset of drawn features for each base estimator.

    classes_ : array of shape = [n_classes]
        The classes labels.

    n_classes_ : int or list
        The number of classes.

    oob_score_ : float
        Score of the training dataset obtained using an out-of-bag estimate.

    oob_decision_function_ : array of shape = [n_samples, n_classes]
        Decision function computed with out-of-bag estimate on the training
        set. If n_estimators is small it might be possible that a data point
        was never left out during the bootstrap. In this case,
        `oob_decision_function_` might contain NaN.

    References
    ----------
    .. [1] L. Breiman, "Pasting small votes for classification in large
           databases and on-line", Machine Learning, 36(1), 85-103, 1999.

    .. [2] L. Breiman, "Bagging predictors", Machine Learning, 24(2), 123-140,
           1996.

    .. [3] T. Ho, "The random subspace method for constructing decision
           forests", Pattern Analysis and Machine Intelligence, 20(8), 832-844,
           1998.

    .. [4] G. Louppe and P. Geurts, "Ensembles on Random Patches", Machine
           Learning and Knowledge Discovery in Databases, 346-361, 2012.

    """
    def __init__(self,
                 base_estimator=None,
                 n_estimators=10,
                 max_samples=1.0,
                 max_features=1.0,
                 bootstrap=True,
                 bootstrap_features=False,
                 oob_score=False,
                 warm_start=False,
                 ratio='auto',
                 replacement=False,
                 n_jobs=1,
                 random_state=None,
                 verbose=0):

        # Call the direct parent's __init__.  The previous code used
        # super(BaggingClassifier, self), which skipped
        # BaggingClassifier.__init__ entirely and only worked because
        # BaseBagging happened to accept the same arguments -- fragile
        # across scikit-learn versions.
        super(BalancedBaggingClassifier, self).__init__(
            base_estimator,
            n_estimators=n_estimators,
            max_samples=max_samples,
            max_features=max_features,
            bootstrap=bootstrap,
            bootstrap_features=bootstrap_features,
            oob_score=oob_score,
            warm_start=warm_start,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose)
        self.ratio = ratio
        self.replacement = replacement

    def _validate_estimator(self, default=DecisionTreeClassifier()):
        """Check the estimator and the n_estimator attribute, set the
        `base_estimator_` attribute.

        The fitted ``base_estimator_`` is a ``Pipeline`` chaining a
        ``RandomUnderSampler`` (configured with this classifier's ``ratio``
        and ``replacement``) with the user-supplied estimator (or ``default``
        when ``base_estimator`` is None).

        Raises
        ------
        ValueError
            If ``n_estimators`` is not a strictly positive integer.
        """
        if not isinstance(self.n_estimators, (numbers.Integral, np.integer)):
            raise ValueError("n_estimators must be an integer, "
                             "got {0}.".format(type(self.n_estimators)))

        if self.n_estimators <= 0:
            raise ValueError("n_estimators must be greater than zero, "
                             "got {0}.".format(self.n_estimators))

        estimator = (self.base_estimator
                     if self.base_estimator is not None else default)

        # BUG FIX: forward `ratio` and `replacement` to the under-sampler.
        # Previously RandomUnderSampler() was instantiated with defaults,
        # silently ignoring both constructor parameters, so the requested
        # balancing strategy never took effect.
        self.base_estimator_ = Pipeline(
            [('sampler', RandomUnderSampler(ratio=self.ratio,
                                            replacement=self.replacement)),
             ('estimator', estimator)])

        # The previous trailing `if self.base_estimator_ is None` check was
        # unreachable (both branches always assigned a Pipeline) and has
        # been removed.

0 commit comments

Comments
 (0)