
Commit d1d9546

iter
1 parent 77f9c75 commit d1d9546

2 files changed: +79 −9 lines changed


examples/ensemble/plot_bagging_classifier.py

Lines changed: 24 additions & 8 deletions
@@ -105,40 +105,56 @@
 # %% [markdown]
 # Roughly Balanced Bagging
 # ------------------------
-# FIXME: narration based on [3]_.
+# While a :class:`~imblearn.under_sampling.RandomUnderSampler` or a
+# :class:`~imblearn.over_sampling.RandomOverSampler` will create exactly the
+# desired number of samples, it does not follow the statistical spirit of the
+# bagging framework. The authors of [3]_ propose to use a negative binomial
+# distribution to compute the number of samples of the majority class to
+# select, and then to perform a random under-sampling.
+#
+# Here, we illustrate this method by implementing a function in charge of the
+# resampling and using a :class:`~imblearn.FunctionSampler` to integrate it
+# within a :class:`~imblearn.pipeline.Pipeline` and
+# :class:`~sklearn.model_selection.cross_validate`.

 # %%
 from collections import Counter
 import numpy as np
 from imblearn import FunctionSampler


-def binomial_resampling(X, y):
+def roughly_balanced_bagging(X, y, replace=False):
+    """Implementation of Roughly Balanced Bagging for a binary problem."""
+    # find the minority and majority classes
     class_counts = Counter(y)
     majority_class = max(class_counts, key=class_counts.get)
     minority_class = min(class_counts, key=class_counts.get)

+    # compute the number of samples to draw from the majority class using
+    # a negative binomial distribution
     n_minority_class = class_counts[minority_class]
-    n_majority_resampled = np.random.negative_binomial(n_minority_class, 0.5)
+    n_majority_resampled = np.random.negative_binomial(n=n_minority_class, p=0.5)

+    # draw randomly with or without replacement
     majority_indices = np.random.choice(
         np.flatnonzero(y == majority_class),
         size=n_majority_resampled,
-        replace=True,
+        replace=replace,
     )
     minority_indices = np.random.choice(
         np.flatnonzero(y == minority_class),
         size=n_minority_class,
-        replace=True,
+        replace=replace,
     )
     indices = np.hstack([majority_indices, minority_indices])

-    X_res, y_res = X[indices], y[indices]
-    return X_res, y_res
+    return X[indices], y[indices]


 # Roughly Balanced Bagging
-rbb = BalancedBaggingClassifier(sampler=FunctionSampler(func=binomial_resampling))
+rbb = BalancedBaggingClassifier(
+    sampler=FunctionSampler(func=roughly_balanced_bagging, kw_args={"replace": True})
+)
 cv_results = cross_validate(rbb, X, y, scoring="balanced_accuracy")

 print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")
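As an aside not taken from the diff above, the choice p=0.5 is what makes the scheme "roughly" balanced: NumPy's negative binomial with n required successes and success probability p has mean n * (1 - p) / p, which equals the minority-class count when p=0.5. Each bootstrap therefore matches the minority class size only in expectation, with some variation from draw to draw. A minimal sketch of that behaviour, using nothing beyond NumPy:

import numpy as np

rng = np.random.RandomState(0)
n_minority = 100

# sizes of the majority-class draw for 10_000 hypothetical bootstraps
draws = rng.negative_binomial(n=n_minority, p=0.5, size=10_000)

print(draws.mean())  # close to n_minority, since E = n * (1 - p) / p = 100
print(draws.std())   # non-zero spread, hence only "roughly" balanced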

imblearn/ensemble/tests/test_bagging.py

Lines changed: 55 additions & 1 deletion
@@ -8,7 +8,7 @@
 import numpy as np
 import pytest

-from sklearn.datasets import load_iris, make_hastie_10_2
+from sklearn.datasets import load_iris, make_hastie_10_2, make_classification
 from sklearn.model_selection import (
     GridSearchCV,
     ParameterGrid,
@@ -24,6 +24,7 @@
 from sklearn.utils._testing import assert_array_almost_equal
 from sklearn.utils._testing import assert_allclose

+from imblearn import FunctionSampler
 from imblearn.datasets import make_imbalance
 from imblearn.ensemble import BalancedBaggingClassifier
 from imblearn.over_sampling import RandomOverSampler, SMOTE
@@ -550,3 +551,56 @@ def test_balanced_bagging_classifier_samplers(sampler, n_samples_bootstrap):
     assert_array_equal(
         list(clf.estimators_[0][-1].class_counts_.values()), n_samples_bootstrap
     )
+
+
+@pytest.mark.parametrize("replace", [True, False])
+def test_balanced_bagging_classifier_with_function_sampler(replace):
+    # check that we can provide a FunctionSampler in BalancedBaggingClassifier
+    X, y = make_classification(
+        n_samples=1_000,
+        n_features=10,
+        n_classes=2,
+        weights=[0.3, 0.7],
+        random_state=0,
+    )
+
+    def roughly_balanced_bagging(X, y, replace=False):
+        """Implementation of Roughly Balanced Bagging for a binary problem."""
+        # find the minority and majority classes
+        class_counts = Counter(y)
+        majority_class = max(class_counts, key=class_counts.get)
+        minority_class = min(class_counts, key=class_counts.get)
+
+        # compute the number of samples to draw from the majority class using
+        # a negative binomial distribution
+        n_minority_class = class_counts[minority_class]
+        n_majority_resampled = np.random.negative_binomial(n=n_minority_class, p=0.5)
+
+        # draw randomly with or without replacement
+        majority_indices = np.random.choice(
+            np.flatnonzero(y == majority_class),
+            size=n_majority_resampled,
+            replace=replace,
+        )
+        minority_indices = np.random.choice(
+            np.flatnonzero(y == minority_class),
+            size=n_minority_class,
+            replace=replace,
+        )
+        indices = np.hstack([majority_indices, minority_indices])
+
+        return X[indices], y[indices]
+
+    # Roughly Balanced Bagging
+    rbb = BalancedBaggingClassifier(
+        base_estimator=CountDecisionTreeClassifier(),
+        n_estimators=2,
+        sampler=FunctionSampler(
+            func=roughly_balanced_bagging, kw_args={"replace": replace}
+        ),
+    )
+    rbb.fit(X, y)
+
+    for estimator in rbb.estimators_:
+        class_counts = estimator[-1].class_counts_
+        assert (class_counts[0] / class_counts[1]) > 0.9
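The final loop is the actual "roughly balanced" check: for each fitted pipeline, the count of class 0 (the minority, given weights=[0.3, 0.7]) divided by the count of class 1 must stay above 0.9, i.e. each bootstrap seen by the tree is close to balanced. Outside the test suite, the same property can be sanity-checked with a single call to the sampler function; the short sketch below assumes the roughly_balanced_bagging function from the diff above is in scope:

from collections import Counter

from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=1_000, n_classes=2, weights=[0.3, 0.7], random_state=0
)

# one resampling round: both classes should come out with comparable counts
X_res, y_res = roughly_balanced_bagging(X, y, replace=True)
print(Counter(y_res))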
