 from sklearn.ensemble._base import _set_random_states
 from sklearn.ensemble._forest import _get_n_samples_bootstrap
 from sklearn.ensemble._forest import _parallel_build_trees
+from sklearn.ensemble._forest import _generate_unsampled_indices
 from sklearn.exceptions import DataConversionWarning
 from sklearn.tree import DecisionTreeClassifier
 from sklearn.utils import check_array
@@ -545,5 +546,65 @@ def fit(self, X, y, sample_weight=None):
 
         return self
 
+    def _set_oob_score(self, X, y):
+        """Compute out-of-bag score."""
+        X = check_array(X, dtype=DTYPE, accept_sparse='csr')
+
+        n_classes_ = self.n_classes_
+        n_samples = y.shape[0]
+
+        oob_decision_function = []
+        oob_score = 0.0
+        predictions = [np.zeros((n_samples, n_classes_[k]))
+                       for k in range(self.n_outputs_)]
+
+        for sampler, estimator in zip(self.samplers_, self.estimators_):
+            X_resample = X[sampler.sample_indices_]
+            y_resample = y[sampler.sample_indices_]
+
+            n_sample_subset = y_resample.shape[0]
+            n_samples_bootstrap = _get_n_samples_bootstrap(
+                n_sample_subset, self.max_samples
+            )
+
+            unsampled_indices = _generate_unsampled_indices(
+                estimator.random_state, n_sample_subset, n_samples_bootstrap
+            )
+            p_estimator = estimator.predict_proba(
+                X_resample[unsampled_indices, :], check_input=False
+            )
+
+            if self.n_outputs_ == 1:
+                p_estimator = [p_estimator]
+
+            for k in range(self.n_outputs_):
+                indices = sampler.sample_indices_[unsampled_indices]
+                predictions[k][indices, :] += p_estimator[k]
+
+        for k in range(self.n_outputs_):
+            if (predictions[k].sum(axis=1) == 0).any():
+                warn("Some inputs do not have OOB scores. "
+                     "This probably means too few trees were used "
+                     "to compute any reliable OOB estimates.")
+
+            with np.errstate(invalid="ignore", divide="ignore"):
+                # with the resampling, some rows are never out-of-bag;
+                # their probability sums are zero and divide by zero here
+                decision = (predictions[k] /
+                            predictions[k].sum(axis=1)[:, np.newaxis])
+            mask_scores = np.isnan(np.sum(decision, axis=1))
+            oob_decision_function.append(decision)
+            oob_score += np.mean(
+                y[~mask_scores, k] == np.argmax(predictions[k][~mask_scores],
+                                                axis=1),
+                axis=0)
+
+        if self.n_outputs_ == 1:
+            self.oob_decision_function_ = oob_decision_function[0]
+        else:
+            self.oob_decision_function_ = oob_decision_function
+
+        self.oob_score_ = oob_score / self.n_outputs_
+
     def _more_tags(self):
         return {"multioutput": False}
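For context on the hunk above: a tree's bootstrap sample is fully determined by its `random_state`, so the rows it never saw (its out-of-bag rows) can be recomputed after fitting rather than stored. A minimal sketch of that mechanism, using the same private sklearn helpers the patch imports (private API, so names and signatures may change between releases):

```python
from sklearn.ensemble._forest import (
    _generate_unsampled_indices,
    _get_n_samples_bootstrap,
)

n_samples = 10
# max_samples=None: each bootstrap draws n_samples rows with replacement
n_samples_bootstrap = _get_n_samples_bootstrap(n_samples, max_samples=None)

# Rows a bootstrap seeded with random_state=0 never draws; these are
# exactly the rows the corresponding tree is scored on in the patch.
oob_rows = _generate_unsampled_indices(0, n_samples, n_samples_bootstrap)
print(oob_rows)  # e.g. array([2, 5, 8]) -- depends on the seed
```

Note that the patch applies this to the *resampled* rows (`sampler.sample_indices_`), so "out-of-bag" here means out of the balanced bootstrap, and the OOB predictions are mapped back to original row indices (`sampler.sample_indices_[unsampled_indices]`) before being accumulated.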
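End to end, the new attributes would be exercised roughly like this; a hedged sketch assuming the estimator this diff targets is imbalanced-learn's `BalancedRandomForestClassifier` (the dataset and parameter values are illustrative):

```python
from sklearn.datasets import make_classification
from imblearn.ensemble import BalancedRandomForestClassifier

# Imbalanced toy problem: roughly 10% positives.
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1],
                           random_state=42)

# bootstrap=True is required for OOB estimates: each tree is scored
# only on rows absent from its (balanced) bootstrap sample.
clf = BalancedRandomForestClassifier(n_estimators=100, bootstrap=True,
                                     oob_score=True, random_state=42)
clf.fit(X, y)

print(clf.oob_score_)                    # OOB accuracy from _set_oob_score
print(clf.oob_decision_function_.shape)  # (n_samples, n_classes)
```

Rows that were never out-of-bag for any tree show up as all-`nan` rows in `oob_decision_function_`; the `mask_scores` masking in the loop above keeps them out of the accuracy instead of propagating `nan` into `oob_score_`.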