Skip to content

Commit 9b31677

Browse files
Shihab-Shahriarglemaitre
authored andcommitted
FIX reproducibility and parallelization of InstanceHardnessThreshold (#599)
1 parent 153f6e0 commit 9b31677

File tree

3 files changed

+40
-15
lines changed

3 files changed

+40
-15
lines changed

doc/whats_new/v0.6.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ scikit-learn:
1515
- :class:`imblearn.under_sampling.ClusterCentroids`
1616
- :class:`imblearn.under_sampling.InstanceHardnessThreshold`
1717

18+
Bug fixes
19+
.........
20+
21+
- :class:`imblearn.under_sampling.InstanceHardnessThreshold` now take into
22+
account the `random_state` and will give deterministic results. In addition,
23+
`cross_val_predict` is used to take advantage of the parallelism.
24+
:pr:`599` by :user:`Shihab Shahriar Khan <Shihab-Shahriar>`.
25+
1826
Maintenance
1927
...........
2028

imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212

1313
from sklearn.base import ClassifierMixin, clone
1414
from sklearn.ensemble import RandomForestClassifier
15+
from sklearn.ensemble._base import _set_random_states
1516
from sklearn.model_selection import StratifiedKFold
17+
from sklearn.model_selection import cross_val_predict
18+
from sklearn.utils import check_random_state
1619
from sklearn.utils import _safe_indexing
1720

1821
from ..base import BaseUnderSampler
@@ -108,7 +111,7 @@ def __init__(
108111
self.cv = cv
109112
self.n_jobs = n_jobs
110113

111-
def _validate_estimator(self):
114+
def _validate_estimator(self, random_state):
112115
"""Private function to create the classifier"""
113116

114117
if (
@@ -117,6 +120,8 @@ def _validate_estimator(self):
117120
and hasattr(self.estimator, "predict_proba")
118121
):
119122
self.estimator_ = clone(self.estimator)
123+
_set_random_states(self.estimator_, random_state)
124+
120125
elif self.estimator is None:
121126
self.estimator_ = RandomForestClassifier(
122127
n_estimators=100,
@@ -131,22 +136,18 @@ def _validate_estimator(self):
131136
)
132137

133138
def _fit_resample(self, X, y):
134-
self._validate_estimator()
139+
random_state = check_random_state(self.random_state)
140+
self._validate_estimator(random_state)
135141

136142
target_stats = Counter(y)
137-
skf = StratifiedKFold(n_splits=self.cv, shuffle=False).split(X, y)
138-
probabilities = np.zeros(y.shape[0], dtype=float)
139-
140-
for train_index, test_index in skf:
141-
X_train = _safe_indexing(X, train_index)
142-
X_test = _safe_indexing(X, test_index)
143-
y_train = _safe_indexing(y, train_index)
144-
y_test = _safe_indexing(y, test_index)
145-
146-
self.estimator_.fit(X_train, y_train)
147-
148-
probs = self.estimator_.predict_proba(X_test)
149-
probabilities[test_index] = probs[range(len(y_test)), y_test]
143+
skf = StratifiedKFold(
144+
n_splits=self.cv, shuffle=True, random_state=random_state,
145+
)
146+
probabilities = cross_val_predict(
147+
self.estimator_, X, y, cv=skf, n_jobs=self.n_jobs,
148+
method='predict_proba'
149+
)
150+
probabilities = probabilities[range(len(y)), y]
150151

151152
idx_under = np.empty((0,), dtype=int)
152153

imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
import pytest
77
import numpy as np
88

9+
910
from sklearn.ensemble import GradientBoostingClassifier
11+
from sklearn.ensemble import RandomForestClassifier
12+
from sklearn.utils._testing import assert_array_equal
1013

1114
from imblearn.under_sampling import InstanceHardnessThreshold
1215

@@ -76,3 +79,16 @@ def test_iht_fit_resample_wrong_class_obj():
7679
iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED)
7780
with pytest.raises(ValueError, match="Invalid parameter `estimator`"):
7881
iht.fit_resample(X, Y)
82+
83+
84+
def test_iht_reproducibility():
85+
from sklearn.datasets import load_digits
86+
X_digits, y_digits = load_digits(return_X_y=True)
87+
idx_sampled = []
88+
for seed in range(5):
89+
est = RandomForestClassifier(n_estimators=10, random_state=seed)
90+
iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED)
91+
iht.fit_resample(X_digits, y_digits)
92+
idx_sampled.append(iht.sample_indices_.copy())
93+
for idx_1, idx_2 in zip(idx_sampled, idx_sampled[1:]):
94+
assert_array_equal(idx_1, idx_2)

0 commit comments

Comments
 (0)