Skip to content

Commit 77f9c75

Browse files
committed
iter
1 parent afa54bd commit 77f9c75

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

examples/ensemble/plot_bagging_classifier.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,48 @@
102102

103103
print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")
104104

105+
# %% [markdown]
106+
# Roughly Balanced Bagging
107+
# ------------------------
108+
# FIXME: narration based on [3]_.
109+
110+
# %%
111+
from collections import Counter
112+
import numpy as np
113+
from imblearn import FunctionSampler
114+
115+
116+
def binomial_resampling(X, y):
117+
class_counts = Counter(y)
118+
majority_class = max(class_counts, key=class_counts.get)
119+
minority_class = min(class_counts, key=class_counts.get)
120+
121+
n_minority_class = class_counts[minority_class]
122+
n_majority_resampled = np.random.negative_binomial(n_minority_class, 0.5)
123+
124+
majority_indices = np.random.choice(
125+
np.flatnonzero(y == majority_class),
126+
size=n_majority_resampled,
127+
replace=True,
128+
)
129+
minority_indices = np.random.choice(
130+
np.flatnonzero(y == minority_class),
131+
size=n_minority_class,
132+
replace=True,
133+
)
134+
indices = np.hstack([majority_indices, minority_indices])
135+
136+
X_res, y_res = X[indices], y[indices]
137+
return X_res, y_res
138+
139+
140+
# Roughly Balanced Bagging
141+
rbb = BalancedBaggingClassifier(sampler=FunctionSampler(func=binomial_resampling))
142+
cv_results = cross_validate(rbb, X, y, scoring="balanced_accuracy")
143+
144+
print(f"{cv_results['test_score'].mean():.3f} +/- {cv_results['test_score'].std():.3f}")
145+
146+
105147
# %% [markdown]
106148
# .. topic:: References:
107149
#
@@ -111,3 +153,7 @@
111153
# .. [2] S. Wang, and X. Yao. "Diversity analysis on imbalanced data sets by
112154
# using ensemble models." 2009 IEEE symposium on computational
113155
# intelligence and data mining. IEEE, 2009.
156+
#
157+
# .. [3] S. Hido, H. Kashima, and Y. Takahashi. "Roughly balanced bagging
158+
# for imbalanced data." Statistical Analysis and Data Mining: The ASA
159+
# Data Science Journal 2.5‐6 (2009): 412-426.

imblearn/ensemble/_bagging.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,10 @@ def __init__(
246246

247247
def _validate_y(self, y):
248248
y_encoded = super()._validate_y(y)
249-
if isinstance(self.sampling_strategy, dict):
249+
if (
250+
isinstance(self.sampling_strategy, dict)
251+
and self.sampler_._sampling_type != "bypass"
252+
):
250253
self._sampling_strategy = {
251254
np.where(self.classes_ == key)[0][0]: value
252255
for key, value in check_sampling_strategy(
@@ -277,7 +280,8 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
277280
else:
278281
base_estimator = clone(default)
279282

280-
self.sampler_.set_params(sampling_strategy=self._sampling_strategy)
283+
if self.sampler_._sampling_type != "bypass":
284+
self.sampler_.set_params(sampling_strategy=self._sampling_strategy)
281285

282286
self.base_estimator_ = Pipeline(
283287
[

0 commit comments

Comments
 (0)