diff --git a/doc/whats_new/v0.6.rst b/doc/whats_new/v0.6.rst index bb6c97196..91403473a 100644 --- a/doc/whats_new/v0.6.rst +++ b/doc/whats_new/v0.6.rst @@ -28,6 +28,10 @@ Bug fixes `cross_val_predict` is used to take advantage of the parallelism. :pr:`599` by :user:`Shihab Shahriar Khan `. +- Fix a bug in :class:`imblearn.ensemble.BalancedRandomForestClassifier` + leading to a wrong computation of the OOB score. + :pr:`656` by :user:`Guillaume Lemaitre `. + Maintenance ........... diff --git a/imblearn/ensemble/_forest.py b/imblearn/ensemble/_forest.py index c23ca96e6..ba79b0105 100644 --- a/imblearn/ensemble/_forest.py +++ b/imblearn/ensemble/_forest.py @@ -19,6 +19,7 @@ from sklearn.ensemble._base import _set_random_states from sklearn.ensemble._forest import _get_n_samples_bootstrap from sklearn.ensemble._forest import _parallel_build_trees +from sklearn.ensemble._forest import _generate_unsampled_indices from sklearn.exceptions import DataConversionWarning from sklearn.tree import DecisionTreeClassifier from sklearn.utils import check_array @@ -545,5 +546,65 @@ def fit(self, X, y, sample_weight=None): return self + def _set_oob_score(self, X, y): + """Compute out-of-bag score.""" + X = check_array(X, dtype=DTYPE, accept_sparse='csr') + + n_classes_ = self.n_classes_ + n_samples = y.shape[0] + + oob_decision_function = [] + oob_score = 0.0 + predictions = [np.zeros((n_samples, n_classes_[k])) + for k in range(self.n_outputs_)] + + for sampler, estimator in zip(self.samplers_, self.estimators_): + X_resample = X[sampler.sample_indices_] + y_resample = y[sampler.sample_indices_] + + n_sample_subset = y_resample.shape[0] + n_samples_bootstrap = _get_n_samples_bootstrap( + n_sample_subset, self.max_samples + ) + + unsampled_indices = _generate_unsampled_indices( + estimator.random_state, n_sample_subset, n_samples_bootstrap + ) + p_estimator = estimator.predict_proba( + X_resample[unsampled_indices, :], check_input=False + ) + + if self.n_outputs_ == 1: + p_estimator = [p_estimator] + + for k in range(self.n_outputs_): + indices = sampler.sample_indices_[unsampled_indices] + predictions[k][indices, :] += p_estimator[k] + + for k in range(self.n_outputs_): + if (predictions[k].sum(axis=1) == 0).any(): + warn("Some inputs do not have OOB scores. " + "This probably means too few trees were used " + "to compute any reliable oob estimates.") + + with np.errstate(invalid="ignore", divide="ignore"): + # with the resampling, we are likely to have rows not included + # for the OOB score leading to division by zero + decision = (predictions[k] / + predictions[k].sum(axis=1)[:, np.newaxis]) + mask_scores = np.isnan(np.sum(decision, axis=1)) + oob_decision_function.append(decision) + oob_score += np.mean( + y[~mask_scores, k] == np.argmax(predictions[k][~mask_scores], + axis=1), + axis=0) + + if self.n_outputs_ == 1: + self.oob_decision_function_ = oob_decision_function[0] + else: + self.oob_decision_function_ = oob_decision_function + + self.oob_score_ = oob_score / self.n_outputs_ + def _more_tags(self): return {"multioutput": False} diff --git a/imblearn/ensemble/tests/test_forest.py b/imblearn/ensemble/tests/test_forest.py index 0c8c615f8..533cec425 100644 --- a/imblearn/ensemble/tests/test_forest.py +++ b/imblearn/ensemble/tests/test_forest.py @@ -4,6 +4,7 @@ from sklearn.datasets import make_classification from sklearn.model_selection import GridSearchCV +from sklearn.model_selection import train_test_split from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import assert_array_equal @@ -108,13 +109,16 @@ def test_balanced_random_forest_sample_weight(imbalanced_dataset): brf.fit(X, y, sample_weight) +@pytest.mark.filterwarnings("ignore:Some inputs do not have OOB scores") def test_balanced_random_forest_oob(imbalanced_dataset): X, y = imbalanced_dataset + X_train, X_test, y_train, y_test = train_test_split( + X, y, random_state=42, stratify=y + ) est = BalancedRandomForestClassifier(oob_score=True, random_state=0) - n_samples = X.shape[0] - est.fit(X[: n_samples // 2, :], y[: n_samples // 2]) - test_score = est.score(X[n_samples // 2:, :], y[n_samples // 2:]) + est.fit(X_train, y_train) + test_score = est.score(X_test, y_test) assert abs(test_score - est.oob_score_) < 0.1 @@ -176,3 +180,16 @@ def test_balanced_random_forest_pruning(imbalanced_dataset): n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count assert n_nodes_no_pruning > n_nodes_pruning + + +def test_balanced_random_forest_oob_binomial(): + # Regression test for #655: check that the oob score is closed to 0.5 + # a binomial experiment. + rng = np.random.RandomState(42) + n_samples = 1000 + X = np.arange(n_samples).reshape(-1, 1) + y = rng.binomial(1, 0.5, size=n_samples) + + erf = BalancedRandomForestClassifier(oob_score=True, random_state=42) + erf.fit(X, y) + assert np.abs(erf.oob_score_ - 0.5) < 0.05