From 6e37a6081a474cb9b0785c75a1badfb68b128816 Mon Sep 17 00:00:00 2001 From: Shihab Shahriar Khan Date: Sat, 7 Sep 2019 19:28:20 +0600 Subject: [PATCH 1/7] FIX reproducibility and parallelization of InstanceHardnessThreshold --- .../_instance_hardness_threshold.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py index 5991a2785..6795a381d 100644 --- a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py @@ -12,7 +12,7 @@ from sklearn.base import ClassifierMixin, clone from sklearn.ensemble import RandomForestClassifier -from sklearn.model_selection import StratifiedKFold +from sklearn.model_selection import StratifiedKFold,cross_val_predict from sklearn.utils import safe_indexing from ..base import BaseUnderSampler @@ -126,6 +126,7 @@ def _validate_estimator(self): isinstance(self.estimator, ClassifierMixin) and hasattr(self.estimator, 'predict_proba')): self.estimator_ = clone(self.estimator) + self.estimator_.set_params(n_jobs=1,random_state=self.random_state) elif self.estimator is None: self.estimator_ = RandomForestClassifier( n_estimators=100, random_state=self.random_state, @@ -143,19 +144,10 @@ def _fit_resample(self, X, y): target_stats = Counter(y) skf = StratifiedKFold( n_splits=self.cv, shuffle=False, - random_state=self.random_state).split(X, y) - probabilities = np.zeros(y.shape[0], dtype=float) - - for train_index, test_index in skf: - X_train = safe_indexing(X, train_index) - X_test = safe_indexing(X, test_index) - y_train = safe_indexing(y, train_index) - y_test = safe_indexing(y, test_index) - - self.estimator_.fit(X_train, y_train) - - probs = self.estimator_.predict_proba(X_test) - probabilities[test_index] = probs[range(len(y_test)), y_test] + random_state=self.random_state) + probabilities = cross_val_predict(self.estimator_, X, y, cv=skf, + n_jobs=self.n_jobs, method='predict_proba') + probabilities = probabilities[range(len(y)), y] idx_under = np.empty((0, ), dtype=int) From b9a66d7f4dd500d2cf14309600b0ddb6cb196f92 Mon Sep 17 00:00:00 2001 From: Shihab Shahriar Khan Date: Sat, 7 Sep 2019 20:53:07 +0600 Subject: [PATCH 2/7] set n_jobs to 1 for estimators that take this parameter --- .../_instance_hardness_threshold.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py index 6795a381d..410873196 100644 --- a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py @@ -12,7 +12,7 @@ from sklearn.base import ClassifierMixin, clone from sklearn.ensemble import RandomForestClassifier -from sklearn.model_selection import StratifiedKFold,cross_val_predict +from sklearn.model_selection import StratifiedKFold, cross_val_predict from sklearn.utils import safe_indexing from ..base import BaseUnderSampler @@ -126,7 +126,9 @@ def _validate_estimator(self): isinstance(self.estimator, ClassifierMixin) and hasattr(self.estimator, 'predict_proba')): self.estimator_ = clone(self.estimator) - self.estimator_.set_params(n_jobs=1,random_state=self.random_state) + self.estimator_.set_params(random_state=self.random_state) + if 'n_jobs' in self.estimator_.get_params().keys(): + self.estimator_.set_params(n_jobs=1) elif self.estimator is None: self.estimator_ = RandomForestClassifier( n_estimators=100, random_state=self.random_state, @@ -146,10 +148,10 @@ def _fit_resample(self, X, y): n_splits=self.cv, shuffle=False, random_state=self.random_state) probabilities = cross_val_predict(self.estimator_, X, y, cv=skf, - n_jobs=self.n_jobs, method='predict_proba') + n_jobs=self.n_jobs, method='predict_proba') probabilities = probabilities[range(len(y)), y] - idx_under = np.empty((0, ), dtype=int) + idx_under = np.empty((0,), dtype=int) for target_class in np.unique(y): if target_class in self.sampling_strategy_.keys(): From 963cdd3da0498b2a16456e6206a266e827fc592d Mon Sep 17 00:00:00 2001 From: Shihab Shahriar Khan Date: Sun, 8 Sep 2019 16:12:05 +0600 Subject: [PATCH 3/7] added reproducibility test and enhanced 'test_iht_init' to use multiple estimators --- .../tests/test_instance_hardness_threshold.py | 58 +++++++++++++------ 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py index 68fd15b3c..c8a9cb891 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py @@ -6,8 +6,8 @@ import pytest import numpy as np -from sklearn.utils.testing import assert_array_equal -from sklearn.ensemble import GradientBoostingClassifier +from sklearn.utils.testing import assert_array_equal, all_estimators +from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier from imblearn.under_sampling import InstanceHardnessThreshold @@ -15,10 +15,10 @@ X = np.array([[-0.3879569, 0.6894251], [-0.09322739, 1.28177189], [ -0.77740357, 0.74097941 ], [0.91542919, -0.65453327], [-0.03852113, 0.40910479], [ - -0.43877303, 1.07366684 -], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [ - -0.30126957, -0.66268378 -], [-0.65571327, 0.42412021], [-0.28305528, 0.30284991], + -0.43877303, 1.07366684 + ], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [ + -0.30126957, -0.66268378 + ], [-0.65571327, 0.42412021], [-0.28305528, 0.30284991], [0.20246714, -0.34727125], [1.06446472, -1.09279772], [0.30543283, -0.02589502], [-0.00717161, 0.00318087]]) Y = np.array([0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0]) @@ -27,11 +27,17 @@ def test_iht_init(): sampling_strategy = 'auto' - iht = InstanceHardnessThreshold( - ESTIMATOR, sampling_strategy=sampling_strategy, random_state=RND_SEED) - - assert iht.sampling_strategy == sampling_strategy - assert iht.random_state == RND_SEED + for name, class_ in all_estimators(): + if not hasattr(class_, 'predict_proba'): + continue + try: + class_() + except TypeError: + continue + iht = InstanceHardnessThreshold( + class_(), sampling_strategy=sampling_strategy, random_state=RND_SEED) + assert iht.sampling_strategy == sampling_strategy + assert iht.random_state == RND_SEED def test_iht_fit_resample(): @@ -41,8 +47,8 @@ def test_iht_fit_resample(): X_gt = np.array([[-0.3879569, 0.6894251], [0.91542919, -0.65453327], [ -0.65571327, 0.42412021 ], [1.06446472, -1.09279772], [0.30543283, -0.02589502], [ - -0.00717161, 0.00318087 - ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], + -0.00717161, 0.00318087 + ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], [-0.43877303, 1.07366684], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [-0.28305528, 0.30284991]]) y_gt = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) @@ -59,8 +65,8 @@ def test_iht_fit_resample_with_indices(): X_gt = np.array([[-0.3879569, 0.6894251], [0.91542919, -0.65453327], [ -0.65571327, 0.42412021 ], [1.06446472, -1.09279772], [0.30543283, -0.02589502], [ - -0.00717161, 0.00318087 - ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], + -0.00717161, 0.00318087 + ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], [-0.43877303, 1.07366684], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [-0.28305528, 0.30284991]]) y_gt = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) @@ -79,8 +85,8 @@ def test_iht_fit_resample_half(): X_gt = np.array([[-0.3879569, 0.6894251], [0.91542919, -0.65453327], [ -0.65571327, 0.42412021 ], [1.06446472, -1.09279772], [0.30543283, -0.02589502], [ - -0.00717161, 0.00318087 - ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], + -0.00717161, 0.00318087 + ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], [-0.03852113, 0.40910479], [-0.43877303, 1.07366684], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [-0.30126957, -0.66268378], [-0.28305528, 0.30284991]]) @@ -97,8 +103,8 @@ def test_iht_fit_resample_class_obj(): X_gt = np.array([[-0.3879569, 0.6894251], [0.91542919, -0.65453327], [ -0.65571327, 0.42412021 ], [1.06446472, -1.09279772], [0.30543283, -0.02589502], [ - -0.00717161, 0.00318087 - ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], + -0.00717161, 0.00318087 + ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], [-0.43877303, 1.07366684], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [-0.28305528, 0.30284991]]) y_gt = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) @@ -112,3 +118,17 @@ def test_iht_fit_resample_wrong_class_obj(): iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED) with pytest.raises(ValueError, match="Invalid parameter `estimator`"): iht.fit_resample(X, Y) + + +def test_iht_reproducibility(): + from sklearn.datasets import load_digits + X_digits, y_digits = load_digits(return_X_y=True) + idx_sampled = [] + for seed in range(5): + est = RandomForestClassifier(n_estimators=10, random_state=seed) + iht = InstanceHardnessThreshold( + estimator=est, + random_state=RND_SEED, return_indices=True) + idx_sampled.append(iht.fit_resample(X_digits, y_digits)[2]) + for idx_1, idx_2 in zip(idx_sampled, idx_sampled[1:]): + assert_array_equal(idx_1, idx_2) From 152524f4efa9292650e8ce0442c734db3632b882 Mon Sep 17 00:00:00 2001 From: Shihab Shahriar Khan Date: Wed, 11 Sep 2019 14:20:12 +0600 Subject: [PATCH 4/7] Modified test_iht_init back to original version, and random_state is now set using _set_random_state --- .../_instance_hardness_threshold.py | 21 +++++--- .../tests/test_instance_hardness_threshold.py | 49 ++++++++----------- 2 files changed, 34 insertions(+), 36 deletions(-) diff --git a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py index 410873196..0cdf6510e 100644 --- a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py @@ -12,8 +12,9 @@ from sklearn.base import ClassifierMixin, clone from sklearn.ensemble import RandomForestClassifier +from sklearn.ensemble.base import _set_random_states from sklearn.model_selection import StratifiedKFold, cross_val_predict -from sklearn.utils import safe_indexing +from sklearn.utils import safe_indexing, check_random_state from ..base import BaseUnderSampler from ...utils import Substitution @@ -119,16 +120,18 @@ def __init__(self, self.cv = cv self.n_jobs = n_jobs - def _validate_estimator(self): + def _validate_estimator(self, random_state): """Private function to create the classifier""" if (self.estimator is not None and isinstance(self.estimator, ClassifierMixin) and hasattr(self.estimator, 'predict_proba')): self.estimator_ = clone(self.estimator) - self.estimator_.set_params(random_state=self.random_state) - if 'n_jobs' in self.estimator_.get_params().keys(): + _set_random_states(self.estimator_, random_state) + try: self.estimator_.set_params(n_jobs=1) + except ValueError: + pass elif self.estimator is None: self.estimator_ = RandomForestClassifier( n_estimators=100, random_state=self.random_state, @@ -141,14 +144,16 @@ def _fit_resample(self, X, y): if self.return_indices: deprecate_parameter(self, '0.4', 'return_indices', 'sample_indices_') - self._validate_estimator() + random_state = check_random_state(self.random_state) + self._validate_estimator(random_state) target_stats = Counter(y) skf = StratifiedKFold( n_splits=self.cv, shuffle=False, - random_state=self.random_state) - probabilities = cross_val_predict(self.estimator_, X, y, cv=skf, - n_jobs=self.n_jobs, method='predict_proba') + random_state=random_state) + probabilities = cross_val_predict( + self.estimator_, X, y, cv=skf, + n_jobs=self.n_jobs, method='predict_proba') probabilities = probabilities[range(len(y)), y] idx_under = np.empty((0,), dtype=int) diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py index c8a9cb891..43192dae9 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py @@ -6,7 +6,7 @@ import pytest import numpy as np -from sklearn.utils.testing import assert_array_equal, all_estimators +from sklearn.utils.testing import assert_array_equal from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier from imblearn.under_sampling import InstanceHardnessThreshold @@ -15,10 +15,10 @@ X = np.array([[-0.3879569, 0.6894251], [-0.09322739, 1.28177189], [ -0.77740357, 0.74097941 ], [0.91542919, -0.65453327], [-0.03852113, 0.40910479], [ - -0.43877303, 1.07366684 - ], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [ - -0.30126957, -0.66268378 - ], [-0.65571327, 0.42412021], [-0.28305528, 0.30284991], + -0.43877303, 1.07366684 +], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [ + -0.30126957, -0.66268378 +], [-0.65571327, 0.42412021], [-0.28305528, 0.30284991], [0.20246714, -0.34727125], [1.06446472, -1.09279772], [0.30543283, -0.02589502], [-0.00717161, 0.00318087]]) Y = np.array([0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0]) @@ -27,17 +27,11 @@ def test_iht_init(): sampling_strategy = 'auto' - for name, class_ in all_estimators(): - if not hasattr(class_, 'predict_proba'): - continue - try: - class_() - except TypeError: - continue - iht = InstanceHardnessThreshold( - class_(), sampling_strategy=sampling_strategy, random_state=RND_SEED) - assert iht.sampling_strategy == sampling_strategy - assert iht.random_state == RND_SEED + iht = InstanceHardnessThreshold( + ESTIMATOR, sampling_strategy=sampling_strategy, random_state=RND_SEED) + + assert iht.sampling_strategy == sampling_strategy + assert iht.random_state == RND_SEED def test_iht_fit_resample(): @@ -47,8 +41,8 @@ def test_iht_fit_resample(): X_gt = np.array([[-0.3879569, 0.6894251], [0.91542919, -0.65453327], [ -0.65571327, 0.42412021 ], [1.06446472, -1.09279772], [0.30543283, -0.02589502], [ - -0.00717161, 0.00318087 - ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], + -0.00717161, 0.00318087 + ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], [-0.43877303, 1.07366684], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [-0.28305528, 0.30284991]]) y_gt = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) @@ -65,8 +59,8 @@ def test_iht_fit_resample_with_indices(): X_gt = np.array([[-0.3879569, 0.6894251], [0.91542919, -0.65453327], [ -0.65571327, 0.42412021 ], [1.06446472, -1.09279772], [0.30543283, -0.02589502], [ - -0.00717161, 0.00318087 - ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], + -0.00717161, 0.00318087 + ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], [-0.43877303, 1.07366684], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [-0.28305528, 0.30284991]]) y_gt = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) @@ -85,8 +79,8 @@ def test_iht_fit_resample_half(): X_gt = np.array([[-0.3879569, 0.6894251], [0.91542919, -0.65453327], [ -0.65571327, 0.42412021 ], [1.06446472, -1.09279772], [0.30543283, -0.02589502], [ - -0.00717161, 0.00318087 - ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], + -0.00717161, 0.00318087 + ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], [-0.03852113, 0.40910479], [-0.43877303, 1.07366684], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [-0.30126957, -0.66268378], [-0.28305528, 0.30284991]]) @@ -103,8 +97,8 @@ def test_iht_fit_resample_class_obj(): X_gt = np.array([[-0.3879569, 0.6894251], [0.91542919, -0.65453327], [ -0.65571327, 0.42412021 ], [1.06446472, -1.09279772], [0.30543283, -0.02589502], [ - -0.00717161, 0.00318087 - ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], + -0.00717161, 0.00318087 + ], [-0.09322739, 1.28177189], [-0.77740357, 0.74097941], [-0.43877303, 1.07366684], [-0.85795321, 0.82980738], [-0.18430329, 0.52328473], [-0.28305528, 0.30284991]]) y_gt = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) @@ -126,9 +120,8 @@ def test_iht_reproducibility(): idx_sampled = [] for seed in range(5): est = RandomForestClassifier(n_estimators=10, random_state=seed) - iht = InstanceHardnessThreshold( - estimator=est, - random_state=RND_SEED, return_indices=True) - idx_sampled.append(iht.fit_resample(X_digits, y_digits)[2]) + iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED) + iht.fit_resample(X_digits, y_digits) + idx_sampled.append(iht.sample_indices_.copy()) for idx_1, idx_2 in zip(idx_sampled, idx_sampled[1:]): assert_array_equal(idx_1, idx_2) From edeee850d40831377f2f08c30e76456f2965ef26 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 17 Nov 2019 12:10:05 +0100 Subject: [PATCH 5/7] fix after merging master --- .../_prototype_selection/_instance_hardness_threshold.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py index 7ade7ac40..2523f8db2 100644 --- a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py @@ -139,11 +139,12 @@ def _validate_estimator(self, random_state): ) def _fit_resample(self, X, y): - self._validate_estimator() + random_state = check_random_state(self.random_state) + self._validate_estimator(random_state) target_stats = Counter(y) skf = StratifiedKFold( - n_splits=self.cv, shuffle=False, random_state=random_state + n_splits=self.cv, shuffle=True, random_state=random_state, ) probabilities = cross_val_predict( self.estimator_, X, y, cv=skf, n_jobs=self.n_jobs, From 81a7622e45c94cd221d2dfd87f673fc79e711a42 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 17 Nov 2019 12:13:06 +0100 Subject: [PATCH 6/7] Let joblib the nested parallelism and over-subscription issue --- .../_prototype_selection/_instance_hardness_threshold.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py index 2523f8db2..9edd2ab11 100644 --- a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py @@ -121,10 +121,7 @@ def _validate_estimator(self, random_state): ): self.estimator_ = clone(self.estimator) _set_random_states(self.estimator_, random_state) - try: - self.estimator_.set_params(n_jobs=1) - except ValueError: - pass + elif self.estimator is None: self.estimator_ = RandomForestClassifier( n_estimators=100, From f6ed9aa6b2a67f32866168a875a87d7b7fcd58ce Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 17 Nov 2019 12:19:53 +0100 Subject: [PATCH 7/7] DOC add entry in whats new --- doc/whats_new/v0.6.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/doc/whats_new/v0.6.rst b/doc/whats_new/v0.6.rst index e5aedaf72..612845e03 100644 --- a/doc/whats_new/v0.6.rst +++ b/doc/whats_new/v0.6.rst @@ -15,6 +15,14 @@ scikit-learn: - :class:`imblearn.under_sampling.ClusterCentroids` - :class:`imblearn.under_sampling.InstanceHardnessThreshold` +Bug fixes +......... + +- :class:`imblearn.under_sampling.InstanceHardnessThreshold` now take into + account the `random_state` and will give deterministic results. In addition, + `cross_val_predict` is used to take advantage of the parallelism. + :pr:`599` by :user:`Shihab Shahriar Khan `. + Maintenance ...........