Commit d8472f4

FIX incorporate resampling when computing OOB score in BRF (#656)
1 parent 3839df1 commit d8472f4

File tree

3 files changed: +85 −3 lines changed

doc/whats_new/v0.6.rst

Lines changed: 4 additions & 0 deletions

@@ -28,6 +28,10 @@ Bug fixes
   `cross_val_predict` is used to take advantage of the parallelism.
   :pr:`599` by :user:`Shihab Shahriar Khan <Shihab-Shahriar>`.

+- Fix a bug in :class:`imblearn.ensemble.BalancedRandomForestClassifier`
+  leading to a wrong computation of the OOB score.
+  :pr:`656` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 Maintenance
 ...........
imblearn/ensemble/_forest.py

Lines changed: 61 additions & 0 deletions

@@ -19,6 +19,7 @@
 from sklearn.ensemble._base import _set_random_states
 from sklearn.ensemble._forest import _get_n_samples_bootstrap
 from sklearn.ensemble._forest import _parallel_build_trees
+from sklearn.ensemble._forest import _generate_unsampled_indices
 from sklearn.exceptions import DataConversionWarning
 from sklearn.tree import DecisionTreeClassifier
 from sklearn.utils import check_array

@@ -545,5 +546,65 @@ def fit(self, X, y, sample_weight=None):

         return self

+    def _set_oob_score(self, X, y):
+        """Compute out-of-bag score."""
+        X = check_array(X, dtype=DTYPE, accept_sparse='csr')
+
+        n_classes_ = self.n_classes_
+        n_samples = y.shape[0]
+
+        oob_decision_function = []
+        oob_score = 0.0
+        predictions = [np.zeros((n_samples, n_classes_[k]))
+                       for k in range(self.n_outputs_)]
+
+        for sampler, estimator in zip(self.samplers_, self.estimators_):
+            X_resample = X[sampler.sample_indices_]
+            y_resample = y[sampler.sample_indices_]
+
+            n_sample_subset = y_resample.shape[0]
+            n_samples_bootstrap = _get_n_samples_bootstrap(
+                n_sample_subset, self.max_samples
+            )
+
+            unsampled_indices = _generate_unsampled_indices(
+                estimator.random_state, n_sample_subset, n_samples_bootstrap
+            )
+            p_estimator = estimator.predict_proba(
+                X_resample[unsampled_indices, :], check_input=False
+            )
+
+            if self.n_outputs_ == 1:
+                p_estimator = [p_estimator]
+
+            for k in range(self.n_outputs_):
+                indices = sampler.sample_indices_[unsampled_indices]
+                predictions[k][indices, :] += p_estimator[k]
+
+        for k in range(self.n_outputs_):
+            if (predictions[k].sum(axis=1) == 0).any():
+                warn("Some inputs do not have OOB scores. "
+                     "This probably means too few trees were used "
+                     "to compute any reliable oob estimates.")
+
+            with np.errstate(invalid="ignore", divide="ignore"):
+                # with the resampling, we are likely to have rows not included
+                # for the OOB score leading to division by zero
+                decision = (predictions[k] /
+                            predictions[k].sum(axis=1)[:, np.newaxis])
+            mask_scores = np.isnan(np.sum(decision, axis=1))
+            oob_decision_function.append(decision)
+            oob_score += np.mean(
+                y[~mask_scores, k] == np.argmax(predictions[k][~mask_scores],
+                                                axis=1),
+                axis=0)
+
+        if self.n_outputs_ == 1:
+            self.oob_decision_function_ = oob_decision_function[0]
+        else:
+            self.oob_decision_function_ = oob_decision_function
+
+        self.oob_score_ = oob_score / self.n_outputs_
+
     def _more_tags(self):
         return {"multioutput": False}
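The heart of the fix is the mapping `sampler.sample_indices_[unsampled_indices]`: each tree in a `BalancedRandomForestClassifier` is trained on a resampled subset of `X`, so the out-of-bag indices generated for that tree refer to rows of the resampled array, not of the original data. The overridden `_set_oob_score` above maps them back before accumulating predictions. Below is a minimal standalone sketch of that mapping; the toy data, the sampler choice, and the hand-picked `unsampled` positions are illustrative assumptions, not the library's internals.

import numpy as np
from imblearn.under_sampling import RandomUnderSampler

# Toy imbalanced data: 15 majority rows, 5 minority rows.
X = np.arange(20).reshape(-1, 1)
y = np.array([0] * 15 + [1] * 5)

sampler = RandomUnderSampler(random_state=0)
X_res, y_res = sampler.fit_resample(X, y)  # 10 rows after balancing

# Pretend positions 5..9 of the *resampled* array are out-of-bag for
# some tree (in the real code they come from _generate_unsampled_indices).
unsampled = np.array([5, 6, 7, 8, 9])

# Pre-fix behaviour: resampled positions used directly as rows of X.
wrong_rows = unsampled
# Post-fix behaviour: map back to the rows of the original X.
right_rows = sampler.sample_indices_[unsampled]

print(wrong_rows)   # positions in X_res, not in X
print(right_rows)   # the actual original-sample indices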

imblearn/ensemble/tests/test_forest.py

Lines changed: 20 additions & 3 deletions

@@ -4,6 +4,7 @@

 from sklearn.datasets import make_classification
 from sklearn.model_selection import GridSearchCV
+from sklearn.model_selection import train_test_split
 from sklearn.utils._testing import assert_allclose
 from sklearn.utils._testing import assert_array_equal

@@ -108,13 +109,16 @@ def test_balanced_random_forest_sample_weight(imbalanced_dataset):
     brf.fit(X, y, sample_weight)


+@pytest.mark.filterwarnings("ignore:Some inputs do not have OOB scores")
 def test_balanced_random_forest_oob(imbalanced_dataset):
     X, y = imbalanced_dataset
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, random_state=42, stratify=y
+    )
     est = BalancedRandomForestClassifier(oob_score=True, random_state=0)

-    n_samples = X.shape[0]
-    est.fit(X[: n_samples // 2, :], y[: n_samples // 2])
-    test_score = est.score(X[n_samples // 2:, :], y[n_samples // 2:])
+    est.fit(X_train, y_train)
+    test_score = est.score(X_test, y_test)

     assert abs(test_score - est.oob_score_) < 0.1

@@ -176,3 +180,16 @@ def test_balanced_random_forest_pruning(imbalanced_dataset):
     n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count

     assert n_nodes_no_pruning > n_nodes_pruning
+
+
+def test_balanced_random_forest_oob_binomial():
+    # Regression test for #655: check that the OOB score is close to 0.5
+    # in a binomial experiment.
+    rng = np.random.RandomState(42)
+    n_samples = 1000
+    X = np.arange(n_samples).reshape(-1, 1)
+    y = rng.binomial(1, 0.5, size=n_samples)
+
+    erf = BalancedRandomForestClassifier(oob_score=True, random_state=42)
+    erf.fit(X, y)
+    assert np.abs(erf.oob_score_ - 0.5) < 0.05
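For reference, a short usage sketch of the fixed estimator; the dataset parameters below are illustrative assumptions. One consequence of resampling, acknowledged by the warning the test above filters out, is that some rows may never be out-of-bag, so the corresponding rows of `oob_decision_function_` end up as `NaN` and are excluded from `oob_score_`.

import numpy as np
from sklearn.datasets import make_classification
from imblearn.ensemble import BalancedRandomForestClassifier

# Illustrative imbalanced dataset (roughly 9:1 class ratio).
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1],
                           random_state=0)
clf = BalancedRandomForestClassifier(oob_score=True, random_state=0)
clf.fit(X, y)  # may warn that some inputs have no OOB scores

print(clf.oob_score_)
# Rows that were never out-of-bag after resampling are NaN.
n_nan = np.isnan(clf.oob_decision_function_).any(axis=1).sum()
print(n_nan, "samples without an OOB estimate")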
