Skip to content

FIX reproducibility and parallelization of InstanceHardnessThreshold #599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
8 changes: 8 additions & 0 deletions doc/whats_new/v0.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ scikit-learn:
- :class:`imblearn.under_sampling.ClusterCentroids`
- :class:`imblearn.under_sampling.InstanceHardnessThreshold`

Bug fixes
.........

- :class:`imblearn.under_sampling.InstanceHardnessThreshold` now takes
`random_state` into account and gives deterministic results. In addition,
`cross_val_predict` is used to take advantage of parallelism.
:pr:`599` by :user:`Shihab Shahriar Khan <Shihab-Shahriar>`.

Maintenance
...........

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@

from sklearn.base import ClassifierMixin, clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble._base import _set_random_states
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_predict
from sklearn.utils import check_random_state
from sklearn.utils import _safe_indexing

from ..base import BaseUnderSampler
Expand Down Expand Up @@ -108,7 +111,7 @@ def __init__(
self.cv = cv
self.n_jobs = n_jobs

def _validate_estimator(self):
def _validate_estimator(self, random_state):
"""Private function to create the classifier"""

if (
Expand All @@ -117,6 +120,8 @@ def _validate_estimator(self):
and hasattr(self.estimator, "predict_proba")
):
self.estimator_ = clone(self.estimator)
_set_random_states(self.estimator_, random_state)

elif self.estimator is None:
self.estimator_ = RandomForestClassifier(
n_estimators=100,
Expand All @@ -131,22 +136,18 @@ def _validate_estimator(self):
)

def _fit_resample(self, X, y):
self._validate_estimator()
random_state = check_random_state(self.random_state)
self._validate_estimator(random_state)

target_stats = Counter(y)
skf = StratifiedKFold(n_splits=self.cv, shuffle=False).split(X, y)
probabilities = np.zeros(y.shape[0], dtype=float)

for train_index, test_index in skf:
X_train = _safe_indexing(X, train_index)
X_test = _safe_indexing(X, test_index)
y_train = _safe_indexing(y, train_index)
y_test = _safe_indexing(y, test_index)

self.estimator_.fit(X_train, y_train)

probs = self.estimator_.predict_proba(X_test)
probabilities[test_index] = probs[range(len(y_test)), y_test]
skf = StratifiedKFold(
n_splits=self.cv, shuffle=True, random_state=random_state,
)
probabilities = cross_val_predict(
self.estimator_, X, y, cv=skf, n_jobs=self.n_jobs,
method='predict_proba'
)
probabilities = probabilities[range(len(y)), y]

idx_under = np.empty((0,), dtype=int)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import pytest
import numpy as np


from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils._testing import assert_array_equal

from imblearn.under_sampling import InstanceHardnessThreshold

Expand Down Expand Up @@ -76,3 +79,16 @@ def test_iht_fit_resample_wrong_class_obj():
iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED)
with pytest.raises(ValueError, match="Invalid parameter `estimator`"):
iht.fit_resample(X, Y)


def test_iht_reproducibility():
    """Check that the sampler's ``random_state`` fixes the selected samples.

    The sampler is run five times on the digits dataset. Each run wraps a
    random forest created with a different ``random_state`` while the
    sampler itself always receives ``RND_SEED``. The test asserts that
    every run yields the same ``sample_indices_``.
    """
    from sklearn.datasets import load_digits

    X_digits, y_digits = load_digits(return_X_y=True)

    def _selected_indices(estimator_seed):
        # Only the seed of the wrapped estimator varies between runs.
        forest = RandomForestClassifier(
            n_estimators=10, random_state=estimator_seed
        )
        sampler = InstanceHardnessThreshold(
            estimator=forest, random_state=RND_SEED
        )
        sampler.fit_resample(X_digits, y_digits)
        return sampler.sample_indices_.copy()

    runs = [_selected_indices(seed) for seed in range(5)]

    # Comparing each run with its successor covers every pair transitively.
    for previous, current in zip(runs[:-1], runs[1:]):
        assert_array_equal(previous, current)