Skip to content

Commit aa80498

Browse files
committed
FIX max_samples was computed on X instead of X_resampled (#661)
1 parent 3025283 commit aa80498

File tree

3 files changed

+34
-4
lines changed

3 files changed

+34
-4
lines changed

doc/whats_new/v0.6.rst

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,24 @@
1+
.. _changes_0_6_1:
2+
3+
Version 0.6.1
4+
==============
5+
6+
**In Development**
7+
8+
This is a bug-fix release to primarily resolve some packaging issues in version
9+
0.6.0. It also includes minor documentation improvements and some bug fixes.
10+
11+
Changelog
12+
---------
13+
14+
Bug fixes
15+
.........
16+
17+
- Fix a bug in :class:`imblearn.ensemble.BalancedRandomForestClassifier`
18+
leading to a wrong number of samples used during fitting due `max_samples`
19+
and therefore a bad computation of the OOB score.
20+
:pr:`656` by :user:`Guillaume Lemaitre <glemaitre>`.
21+
122
.. _changes_0_6:
223

324
Version 0.6.0

imblearn/ensemble/_forest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def _local_parallel_build_trees(
5353
X_resampled, y_resampled = sampler.fit_resample(X, y)
5454
if sample_weight is not None:
5555
sample_weight = _safe_indexing(sample_weight, sampler.sample_indices_)
56+
if _get_n_samples_bootstrap is not None:
57+
n_samples_bootstrap = min(n_samples_bootstrap, X_resampled.shape[0])
5658
tree = _parallel_build_trees(
5759
tree,
5860
forest,
@@ -214,6 +216,9 @@ class BalancedRandomForestClassifier(RandomForestClassifier):
214216
- If int, then draw `max_samples` samples.
215217
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
216218
`max_samples` should be in the interval `(0, 1)`.
219+
Be aware that the final number samples used will be the minimum between
220+
the number of samples given in `max_samples` and the number of samples
221+
obtained after resampling.
217222
218223
.. versionadded:: 0.22
219224
Added in `scikit-learn` in 0.22

imblearn/ensemble/tests/test_forest.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ def test_balanced_random_forest_oob(imbalanced_dataset):
115115
X_train, X_test, y_train, y_test = train_test_split(
116116
X, y, random_state=42, stratify=y
117117
)
118-
est = BalancedRandomForestClassifier(oob_score=True, random_state=0)
118+
est = BalancedRandomForestClassifier(
119+
oob_score=True, random_state=0, n_estimators=1000
120+
)
119121

120122
est.fit(X_train, y_train)
121123
test_score = est.score(X_test, y_test)
@@ -182,14 +184,16 @@ def test_balanced_random_forest_pruning(imbalanced_dataset):
182184
assert n_nodes_no_pruning > n_nodes_pruning
183185

184186

185-
def test_balanced_random_forest_oob_binomial():
187+
@pytest.mark.parametrize("ratio", [0.5, 0.1])
188+
@pytest.mark.filterwarnings("ignore:Some inputs do not have OOB scores")
189+
def test_balanced_random_forest_oob_binomial(ratio):
186190
# Regression test for #655: check that the oob score is closed to 0.5
187191
# a binomial experiment.
188192
rng = np.random.RandomState(42)
189193
n_samples = 1000
190194
X = np.arange(n_samples).reshape(-1, 1)
191-
y = rng.binomial(1, 0.5, size=n_samples)
195+
y = rng.binomial(1, ratio, size=n_samples)
192196

193197
erf = BalancedRandomForestClassifier(oob_score=True, random_state=42)
194198
erf.fit(X, y)
195-
assert np.abs(erf.oob_score_ - 0.5) < 0.05
199+
assert np.abs(erf.oob_score_ - 0.5) < 0.1

0 commit comments

Comments
 (0)