Skip to content

Commit b4faf5f

Browse files
gmogolglemaitre
andauthored
FIX InstanceHardnessThreshold accepts classifier included in a Pipeline (#1049)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
1 parent 4c35c0f commit b4faf5f

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

doc/whats_new/v0.12.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
.. _changes_0_12:
22

3+
Version 0.12.1
4+
==============
5+
6+
**In progress**
7+
8+
Changelog
9+
---------
10+
11+
Bug fixes
12+
.........
13+
14+
- Fix a bug in :class:`~imblearn.under_sampling.InstanceHardnessThreshold` where
15+
`estimator` could not be a :class:`~sklearn.pipeline.Pipeline` object.
16+
:pr:`1049` by :user:`Gonenc Mogol <gmogol>`.
17+
318
Version 0.12.0
419
==============
520

imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections import Counter
1111

1212
import numpy as np
13-
from sklearn.base import ClassifierMixin, clone
13+
from sklearn.base import clone, is_classifier
1414
from sklearn.ensemble import RandomForestClassifier
1515
from sklearn.ensemble._base import _set_random_states
1616
from sklearn.model_selection import StratifiedKFold, cross_val_predict
@@ -140,7 +140,7 @@ def _validate_estimator(self, random_state):
140140

141141
if (
142142
self.estimator is not None
143-
and isinstance(self.estimator, ClassifierMixin)
143+
and is_classifier(self.estimator)
144144
and hasattr(self.estimator, "predict_proba")
145145
):
146146
self.estimator_ = clone(self.estimator)

imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
88
from sklearn.naive_bayes import GaussianNB as NB
9+
from sklearn.pipeline import make_pipeline
910
from sklearn.utils._testing import assert_array_equal
1011

1112
from imblearn.under_sampling import InstanceHardnessThreshold
@@ -93,3 +94,19 @@ def test_iht_fit_resample_default_estimator():
9394
assert isinstance(iht.estimator_, RandomForestClassifier)
9495
assert X_resampled.shape == (12, 2)
9596
assert y_resampled.shape == (12,)
97+
98+
99+
def test_iht_estimator_pipeline():
100+
"""Check that we can pass a pipeline containing a classifier.
101+
102+
Checking if we have a classifier should not be based on inheriting from
103+
`ClassifierMixin`.
104+
105+
Non-regression test for:
106+
https://github.com/scikit-learn-contrib/imbalanced-learn/pull/1049
107+
"""
108+
model = make_pipeline(GradientBoostingClassifier(random_state=RND_SEED))
109+
iht = InstanceHardnessThreshold(estimator=model, random_state=RND_SEED)
110+
X_resampled, y_resampled = iht.fit_resample(X, Y)
111+
assert X_resampled.shape == (12, 2)
112+
assert y_resampled.shape == (12,)

0 commit comments

Comments
 (0)