
Commit e2fee9a

chkoar authored and glemaitre committed
FIX use a stump as base estimator in RUSBoostClassifier (#545)
1 parent 56fb7d2 commit e2fee9a

File tree

4 files changed: +28 / -20 lines


doc/ensemble.rst

Lines changed: 3 additions & 2 deletions
@@ -93,12 +93,13 @@ Several methods taking advantage of boosting have been designed.
 a boosting iteration [SKHN2010]_::
 
     >>> from imblearn.ensemble import RUSBoostClassifier
-    >>> rusboost = RUSBoostClassifier(random_state=0)
+    >>> rusboost = RUSBoostClassifier(n_estimators=200, algorithm='SAMME.R',
+    ...                               random_state=0)
     >>> rusboost.fit(X_train, y_train)  # doctest: +ELLIPSIS
     RUSBoostClassifier(...)
     >>> y_pred = rusboost.predict(X_test)
     >>> balanced_accuracy_score(y_test, y_pred)  # doctest: +ELLIPSIS
-    0.74...
+    0.66...
 
 A specific method which uses ``AdaBoost`` as learners in the bagging classifier
 is called EasyEnsemble. The :class:`EasyEnsembleClassifier` allows to bag
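The doctest above relies on ``X_train``/``X_test`` and the metric import being defined earlier in the guide. For context, here is a self-contained sketch of the same workflow; the dataset parameters are illustrative (borrowed from this commit's test fixture), so the resulting score will not match the ``0.66...`` doctest value exactly:

    # Self-contained sketch of the documented workflow; dataset parameters
    # are illustrative, so the printed score will differ from the doctest.
    from sklearn.datasets import make_classification
    from sklearn.metrics import balanced_accuracy_score
    from sklearn.model_selection import train_test_split

    from imblearn.ensemble import RUSBoostClassifier

    X, y = make_classification(n_samples=10000, n_features=3, n_informative=2,
                               n_redundant=0, n_repeated=0, n_classes=3,
                               n_clusters_per_class=1,
                               weights=[0.01, 0.05, 0.94], class_sep=0.8,
                               random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y,
                                                        random_state=0)

    rusboost = RUSBoostClassifier(n_estimators=200, algorithm='SAMME.R',
                                  random_state=0)
    rusboost.fit(X_train, y_train)
    y_pred = rusboost.predict(X_test)
    print(balanced_accuracy_score(y_test, y_pred))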

doc/whats_new/v0.5.rst

Lines changed: 15 additions & 0 deletions
@@ -6,6 +6,16 @@ Version 0.5 (under development)
 Changelog
 ---------
 
+Changed models
+..............
+
+The following models or functions might give different results even if the
+data ``X`` and ``y`` are the same.
+
+* :class:`imblearn.ensemble.RUSBoostClassifier` default estimator changed from
+  :class:`sklearn.tree.DecisionTreeClassifier` with full depth to a decision
+  stump (i.e., a tree with ``max_depth=1``).
+
 Documentation
 .............

@@ -53,3 +63,8 @@ Bug
 - Fix bug in :class:`imblearn.pipeline.Pipeline` where None could be the final
   estimator.
   :pr:`554` by :user:`Oliver Rausch <orausch>`.
+
+- Fix bug by changing the default depth in
+  :class:`imblearn.ensemble.RUSBoostClassifier` to get a decision stump as a
+  weak learner, as in the original paper.
+  :pr:`545` by :user:`Christos Aridas <chkoar>`.
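Because the changelog warns that results may change, users who depended on the pre-0.5 behaviour can pin it explicitly. A minimal sketch, assuming the 0.5-era ``base_estimator`` parameter name:

    # Restore the old default (full-depth trees) by passing it explicitly.
    from sklearn.tree import DecisionTreeClassifier

    from imblearn.ensemble import RUSBoostClassifier

    rusboost = RUSBoostClassifier(
        base_estimator=DecisionTreeClassifier(),  # max_depth=None, the pre-0.5 default
        random_state=0,
    )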

imblearn/ensemble/_weight_boosting.py

Lines changed: 6 additions & 16 deletions
@@ -30,10 +30,11 @@ class RUSBoostClassifier(AdaBoostClassifier):
 
     Parameters
     ----------
-    base_estimator : object, optional (default=DecisionTreeClassifier)
+    base_estimator : object, optional (default=None)
         The base estimator from which the boosted ensemble is built.
-        Support for sample weighting is required, as well as proper `classes_`
-        and `n_classes_` attributes.
+        Support for sample weighting is required, as well as proper
+        ``classes_`` and ``n_classes_`` attributes. If ``None``, then
+        the base estimator is ``DecisionTreeClassifier(max_depth=1)``.
 
     n_estimators : integer, optional (default=50)
         The maximum number of estimators at which boosting is terminated.

@@ -152,21 +153,10 @@ def fit(self, X, y, sample_weight=None):
         super().fit(X, y, sample_weight)
         return self
 
-    def _validate_estimator(self, default=DecisionTreeClassifier()):
+    def _validate_estimator(self):
         """Check the estimator and the n_estimator attribute, set the
         `base_estimator_` attribute."""
-        if not isinstance(self.n_estimators, (numbers.Integral, np.integer)):
-            raise ValueError("n_estimators must be an integer, "
-                             "got {}.".format(type(self.n_estimators)))
-
-        if self.n_estimators <= 0:
-            raise ValueError("n_estimators must be greater than zero, "
-                             "got {}.".format(self.n_estimators))
-
-        if self.base_estimator is not None:
-            self.base_estimator_ = clone(self.base_estimator)
-        else:
-            self.base_estimator_ = clone(default)
+        super()._validate_estimator()
 
         self.base_sampler_ = RandomUnderSampler(
             sampling_strategy=self.sampling_strategy,
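Delegating to ``super()._validate_estimator()`` reuses ``AdaBoostClassifier``'s validation, which already performs the removed ``n_estimators`` type and positivity checks and defaults the base estimator to ``DecisionTreeClassifier(max_depth=1)``, i.e. a stump. A quick sanity check, assuming the 0.5-era fitted attribute names:

    # Verify that the default weak learner is now a depth-1 tree.
    from sklearn.datasets import make_classification

    from imblearn.ensemble import RUSBoostClassifier

    X, y = make_classification(weights=[0.1, 0.9], random_state=0)
    clf = RUSBoostClassifier(random_state=0).fit(X, y)
    print(clf.base_estimator_)           # a DecisionTreeClassifier with max_depth=1
    print(clf.estimators_[0].max_depth)  # 1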

imblearn/ensemble/tests/test_weight_boosting.py

Lines changed: 4 additions & 2 deletions
@@ -11,7 +11,7 @@
 
 @pytest.fixture
 def imbalanced_dataset():
-    return make_classification(n_samples=10000, n_features=2, n_informative=2,
+    return make_classification(n_samples=10000, n_features=3, n_informative=2,
                                n_redundant=0, n_repeated=0, n_classes=3,
                                n_clusters_per_class=1,
                                weights=[0.01, 0.05, 0.94], class_sep=0.8,

@@ -32,7 +32,9 @@ def test_rusboost_error(imbalanced_dataset, boosting_params, err_msg):
 @pytest.mark.parametrize('algorithm', ['SAMME', 'SAMME.R'])
 def test_rusboost(imbalanced_dataset, algorithm):
     X, y = imbalanced_dataset
-    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)
+    X_train, X_test, y_train, y_test = train_test_split(X, y,
+                                                        stratify=y,
+                                                        random_state=1)
     classes = np.unique(y)
 
     n_estimators = 500
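Two details worth noting in the test changes: ``make_classification`` requires ``n_informative + n_redundant + n_repeated <= n_features``, so moving to ``n_features=3`` simply adds one uninformative noise feature, and pinning ``random_state=1`` in ``train_test_split`` makes the split, and hence the asserted scores, deterministic. A sketch of the class distribution the fixture produces (the seed below is chosen for illustration; the fixture's own seed is not shown in this hunk):

    # Inspect the imbalance produced by the fixture's parameters.
    import numpy as np
    from sklearn.datasets import make_classification

    X, y = make_classification(n_samples=10000, n_features=3, n_informative=2,
                               n_redundant=0, n_repeated=0, n_classes=3,
                               n_clusters_per_class=1,
                               weights=[0.01, 0.05, 0.94], class_sep=0.8,
                               random_state=0)  # seed chosen for this sketch
    print(np.bincount(y))  # roughly [100, 500, 9400], per the weights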
