diff --git a/doc/whats_new/v0.6.rst b/doc/whats_new/v0.6.rst index e5aedaf72..612845e03 100644 --- a/doc/whats_new/v0.6.rst +++ b/doc/whats_new/v0.6.rst @@ -15,6 +15,14 @@ scikit-learn: - :class:`imblearn.under_sampling.ClusterCentroids` - :class:`imblearn.under_sampling.InstanceHardnessThreshold` +Bug fixes +......... + +- :class:`imblearn.under_sampling.InstanceHardnessThreshold` now takes into + account the `random_state` and will give deterministic results. In addition, + `cross_val_predict` is used to take advantage of the parallelism. + :pr:`599` by :user:`Shihab Shahriar Khan <Shihab-Shahriar>`. + Maintenance ........... diff --git a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py index cc3b61634..9edd2ab11 100644 --- a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py @@ -12,7 +12,10 @@ from sklearn.base import ClassifierMixin, clone from sklearn.ensemble import RandomForestClassifier +from sklearn.ensemble._base import _set_random_states from sklearn.model_selection import StratifiedKFold +from sklearn.model_selection import cross_val_predict +from sklearn.utils import check_random_state from sklearn.utils import _safe_indexing from ..base import BaseUnderSampler @@ -108,7 +111,7 @@ def __init__( self.cv = cv self.n_jobs = n_jobs - def _validate_estimator(self): + def _validate_estimator(self, random_state): """Private function to create the classifier""" if ( @@ -117,6 +120,8 @@ def _validate_estimator(self): and hasattr(self.estimator, "predict_proba") ): self.estimator_ = clone(self.estimator) + _set_random_states(self.estimator_, random_state) + elif self.estimator is None: self.estimator_ = RandomForestClassifier( n_estimators=100, @@ -131,22 +136,18 @@ def _validate_estimator(self): ) def _fit_resample(self, X, y): - self._validate_estimator() + 
random_state = check_random_state(self.random_state) + self._validate_estimator(random_state) target_stats = Counter(y) - skf = StratifiedKFold(n_splits=self.cv, shuffle=False).split(X, y) - probabilities = np.zeros(y.shape[0], dtype=float) - - for train_index, test_index in skf: - X_train = _safe_indexing(X, train_index) - X_test = _safe_indexing(X, test_index) - y_train = _safe_indexing(y, train_index) - y_test = _safe_indexing(y, test_index) - - self.estimator_.fit(X_train, y_train) - - probs = self.estimator_.predict_proba(X_test) - probabilities[test_index] = probs[range(len(y_test)), y_test] + skf = StratifiedKFold( + n_splits=self.cv, shuffle=True, random_state=random_state, + ) + probabilities = cross_val_predict( + self.estimator_, X, y, cv=skf, n_jobs=self.n_jobs, + method='predict_proba' + ) + probabilities = probabilities[range(len(y)), y] idx_under = np.empty((0,), dtype=int) diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py index f720e7b38..6f0cf51f4 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py @@ -6,7 +6,10 @@ import pytest import numpy as np + from sklearn.ensemble import GradientBoostingClassifier +from sklearn.ensemble import RandomForestClassifier +from sklearn.utils._testing import assert_array_equal from imblearn.under_sampling import InstanceHardnessThreshold @@ -76,3 +79,16 @@ def test_iht_fit_resample_wrong_class_obj(): iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED) with pytest.raises(ValueError, match="Invalid parameter `estimator`"): iht.fit_resample(X, Y) + + +def test_iht_reproducibility(): + from sklearn.datasets import load_digits + X_digits, y_digits = load_digits(return_X_y=True) + idx_sampled = [] + for seed in range(5): + est = 
RandomForestClassifier(n_estimators=10, random_state=seed) + iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED) + iht.fit_resample(X_digits, y_digits) + idx_sampled.append(iht.sample_indices_.copy()) + for idx_1, idx_2 in zip(idx_sampled, idx_sampled[1:]): + assert_array_equal(idx_1, idx_2)