Commit 3aeddf3
ENH add BalancedBaggingClassifier
1 parent 333d81b commit 3aeddf3

File tree

3 files changed: +644 -5 lines changed

imblearn/ensemble/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 under-sampled subsets combined inside an ensemble.
 """
 
-from .easy_ensemble import EasyEnsemble
+from .easy_ensemble import EasyEnsemble, BalancedBaggingClassifier
 from .balance_cascade import BalanceCascade
 
-__all__ = ['EasyEnsemble', 'BalanceCascade']
+__all__ = ['EasyEnsemble', 'BalancedBaggingClassifier', 'BalanceCascade']
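With this re-export, the new classifier is importable directly from the subpackage; a one-line sanity check, not part of the commit:

    from imblearn.ensemble import BalancedBaggingClassifier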

imblearn/ensemble/easy_ensemble.py

Lines changed: 202 additions & 1 deletion

@@ -4,16 +4,44 @@
 # Christos Aridas
 # License: MIT
 
+import numbers
+
 import numpy as np
 
-from sklearn.utils import check_random_state
+import sklearn
+from sklearn.base import clone
+from sklearn.ensemble import BaggingClassifier
+from sklearn.ensemble.bagging import _generate_bagging_indices
+from sklearn.tree import DecisionTreeClassifier
+from sklearn.utils import check_random_state, indices_to_mask
 
 from .base import BaseEnsembleSampler
+from ..pipeline import Pipeline
 from ..under_sampling import RandomUnderSampler
 
 MAX_INT = np.iinfo(np.int32).max
 
 
+old_generate = _generate_bagging_indices
+
+
+def _masked_bagging_indices(random_state, bootstrap_features,
+                            bootstrap_samples, n_features, n_samples,
+                            max_features, max_samples):
+    """Monkey-patch to always get a mask instead of indices."""
+    feature_indices, sample_indices = old_generate(random_state,
+                                                   bootstrap_features,
+                                                   bootstrap_samples,
+                                                   n_features, n_samples,
+                                                   max_features, max_samples)
+    sample_indices = indices_to_mask(sample_indices, n_samples)
+
+    return feature_indices, sample_indices
+
+
+sklearn.ensemble.bagging._generate_bagging_indices = _masked_bagging_indices
+
+
 class EasyEnsemble(BaseEnsembleSampler):
     """Create ensemble sets by iteratively applying random under-sampling.
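The monkey-patch above replaces scikit-learn's private `_generate_bagging_indices` so that the sample axis comes back as a boolean mask rather than an index array; `indices_to_mask` (imported above) does the conversion. A minimal standalone illustration, not part of the commit, assuming a scikit-learn version of this era where ``sklearn.utils.indices_to_mask`` is public:

    import numpy as np
    from sklearn.utils import indices_to_mask

    # In-bag indices drawn (with replacement) for one estimator over
    # 5 training samples; note the duplicate index 2.
    sample_indices = np.array([0, 2, 2, 4])

    # Boolean mask over all 5 samples; duplicates collapse to a single
    # True, which is exactly what the masked variant returns.
    mask = indices_to_mask(sample_indices, 5)
    print(mask)  # [ True False  True False  True]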
@@ -147,3 +175,176 @@ def _sample(self, X, y):
                     np.array(idx_under))
         else:
             return np.array(X_resampled), np.array(y_resampled)
+
+
+class BalancedBaggingClassifier(BaggingClassifier):
+    """A Bagging classifier with additional balancing.
+
+    This implementation of Bagging is similar to the scikit-learn
+    implementation. It includes an additional step to balance the training
+    set at fit time using a ``RandomUnderSampler``.
+
+    Read more in the :ref:`User Guide <bagging>`.
+
+    Parameters
+    ----------
+    base_estimator : object or None, optional (default=None)
+        The base estimator to fit on random subsets of the dataset.
+        If None, then the base estimator is a decision tree.
+
+    n_estimators : int, optional (default=10)
+        The number of base estimators in the ensemble.
+
+    max_samples : int or float, optional (default=1.0)
+        The number of samples to draw from X to train each base estimator.
+        - If int, then draw ``max_samples`` samples.
+        - If float, then draw ``max_samples * X.shape[0]`` samples.
+
+    max_features : int or float, optional (default=1.0)
+        The number of features to draw from X to train each base estimator.
+        - If int, then draw ``max_features`` features.
+        - If float, then draw ``max_features * X.shape[1]`` features.
+
+    bootstrap : boolean, optional (default=True)
+        Whether samples are drawn with replacement.
+
+    bootstrap_features : boolean, optional (default=False)
+        Whether features are drawn with replacement.
+
+    oob_score : bool, optional (default=False)
+        Whether to use out-of-bag samples to estimate
+        the generalization error.
+
+    warm_start : bool, optional (default=False)
+        When set to True, reuse the solution of the previous call to fit
+        and add more estimators to the ensemble; otherwise, just fit
+        a whole new ensemble.
+        .. versionadded:: 0.17
+           *warm_start* constructor parameter.
+
+    ratio : str, dict, or callable, optional (default='auto')
+        Ratio to use for resampling the data set.
+
+        - If ``str``, has to be one of: (i) ``'minority'``: resample the
+          minority class; (ii) ``'majority'``: resample the majority class;
+          (iii) ``'not minority'``: resample all classes apart from the
+          minority class; (iv) ``'all'``: resample all classes; and (v)
+          ``'auto'``: corresponds to ``'all'`` for over-sampling methods and
+          ``'not minority'`` for under-sampling methods. The classes targeted
+          will be over-sampled or under-sampled to achieve an equal number
+          of samples with the majority or minority class.
+        - If ``dict``, the keys correspond to the targeted classes. The
+          values correspond to the desired number of samples.
+        - If callable, a function taking ``y`` and returning a ``dict``. The
+          keys correspond to the targeted classes. The values correspond to
+          the desired number of samples.
+
+    replacement : bool, optional (default=False)
+        Whether or not to sample randomly with replacement.
+
+    n_jobs : int, optional (default=1)
+        The number of jobs to run in parallel for both ``fit`` and
+        ``predict``. If -1, then the number of jobs is set to the number
+        of cores.
+
+    random_state : int, RandomState instance or None, optional (default=None)
+        If int, random_state is the seed used by the random number generator;
+        If RandomState instance, random_state is the random number generator;
+        If None, the random number generator is the RandomState instance used
+        by ``np.random``.
+
+    verbose : int, optional (default=0)
+        Controls the verbosity of the building process.
+
+    Attributes
+    ----------
+    base_estimator_ : estimator
+        The base estimator from which the ensemble is grown.
+
+    estimators_ : list of estimators
+        The collection of fitted base estimators.
+
+    estimators_samples_ : list of arrays
+        The subset of drawn samples (i.e., the in-bag samples) for each base
+        estimator. Each subset is defined by a boolean mask.
+
+    estimators_features_ : list of arrays
+        The subset of drawn features for each base estimator.
+
+    classes_ : array of shape = [n_classes]
+        The class labels.
+
+    n_classes_ : int or list
+        The number of classes.
+
+    oob_score_ : float
+        Score of the training dataset obtained using an out-of-bag estimate.
+
+    oob_decision_function_ : array of shape = [n_samples, n_classes]
+        Decision function computed with out-of-bag estimate on the training
+        set. If n_estimators is small it might be possible that a data point
+        was never left out during the bootstrap. In this case,
+        ``oob_decision_function_`` might contain NaN.
+
+    References
+    ----------
+    .. [1] L. Breiman, "Pasting small votes for classification in large
+           databases and on-line", Machine Learning, 36(1), 85-103, 1999.
+    .. [2] L. Breiman, "Bagging predictors", Machine Learning, 24(2),
+           123-140, 1996.
+    .. [3] T. Ho, "The random subspace method for constructing decision
+           forests", Pattern Analysis and Machine Intelligence, 20(8),
+           832-844, 1998.
+    .. [4] G. Louppe and P. Geurts, "Ensembles on Random Patches", Machine
+           Learning and Knowledge Discovery in Databases, 346-361, 2012.
+
+    """
+    def __init__(self,
+                 base_estimator=None,
+                 n_estimators=10,
+                 max_samples=1.0,
+                 max_features=1.0,
+                 bootstrap=True,
+                 bootstrap_features=False,
+                 oob_score=False,
+                 warm_start=False,
+                 ratio='auto',
+                 replacement=False,
+                 n_jobs=1,
+                 random_state=None,
+                 verbose=0):
+
+        super(BaggingClassifier, self).__init__(
+            base_estimator,
+            n_estimators=n_estimators,
+            max_samples=max_samples,
+            max_features=max_features,
+            bootstrap=bootstrap,
+            bootstrap_features=bootstrap_features,
+            oob_score=oob_score,
+            warm_start=warm_start,
+            n_jobs=n_jobs,
+            random_state=random_state,
+            verbose=verbose)
+        self.ratio = ratio
+        self.replacement = replacement
+
+    def _validate_estimator(self, default=DecisionTreeClassifier()):
+        """Check the estimator and the n_estimators attribute, and set the
+        ``base_estimator_`` attribute."""
+        if not isinstance(self.n_estimators, (numbers.Integral, np.integer)):
+            raise ValueError("n_estimators must be an integer, "
+                             "got {0}.".format(type(self.n_estimators)))
+
+        if self.n_estimators <= 0:
+            raise ValueError("n_estimators must be greater than zero, "
+                             "got {0}.".format(self.n_estimators))
+
+        if self.base_estimator is not None:
+            base_estimator = clone(self.base_estimator)
+        else:
+            base_estimator = clone(default)
+
+        self.base_estimator_ = Pipeline(
+            [('sampler', RandomUnderSampler(ratio=self.ratio,
+                                            replacement=self.replacement)),
+             ('classifier', base_estimator)])
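Taken together, ``_validate_estimator`` wraps each base estimator in a ``Pipeline`` whose first step is a ``RandomUnderSampler``, so every bootstrap subset drawn by the bagging machinery is re-balanced before the classifier sees it. A usage sketch, not part of the commit; the dataset and split helpers are ordinary scikit-learn utilities, and the values used here are illustrative:

    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split
    from imblearn.ensemble import BalancedBaggingClassifier

    # A 2-class problem with roughly a 10% minority class.
    X, y = make_classification(n_samples=2000, n_features=20,
                               weights=[0.9, 0.1], random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        random_state=0)

    # Each of the 10 trees is fitted on a bootstrap that the internal
    # RandomUnderSampler first balances (ratio='auto' equalizes classes).
    bbc = BalancedBaggingClassifier(n_estimators=10, ratio='auto',
                                    random_state=0)
    bbc.fit(X_train, y_train)
    print(bbc.score(X_test, y_test))

Passing a ``dict`` for ``ratio`` would instead pin the desired number of samples per targeted class, as described in the docstring above.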
