Skip to content

Commit 56fb7d2

Browse files
chkoar authored and glemaitre committed
[MRG] MAINT explicit fail messages on non supported targets (#544)
1 parent f30d0cf commit 56fb7d2

File tree

7 files changed

+33
-11
lines changed

7 files changed

+33
-11
lines changed

imblearn/ensemble/_bagging.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..pipeline import Pipeline
1616
from ..under_sampling import RandomUnderSampler
1717
from ..under_sampling.base import BaseUnderSampler
18-
from ..utils import Substitution
18+
from ..utils import Substitution, check_target_type
1919
from ..utils._docstring import _random_state_docstring
2020

2121

@@ -240,6 +240,7 @@ def fit(self, X, y):
240240
self : object
241241
Returns self.
242242
"""
243+
check_target_type(y)
243244
# RandomUnderSampler is not supporting sample_weight. We need to pass
244245
# None.
245246
return self._fit(X, y, self.max_samples, sample_weight=None)

imblearn/ensemble/_easy_ensemble.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .base import BaseEnsembleSampler
1818
from ..under_sampling import RandomUnderSampler
1919
from ..under_sampling.base import BaseUnderSampler
20-
from ..utils import Substitution
20+
from ..utils import Substitution, check_target_type
2121
from ..utils._docstring import _random_state_docstring
2222
from ..pipeline import Pipeline
2323

@@ -290,6 +290,7 @@ def fit(self, X, y):
290290
self : object
291291
Returns self.
292292
"""
293+
check_target_type(y)
293294
# RandomUnderSampler is not supporting sample_weight. We need to pass
294295
# None.
295296
return self._fit(X, y, self.max_samples, sample_weight=None)

imblearn/ensemble/_weight_boosting.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ..under_sampling.base import BaseUnderSampler
1313
from ..under_sampling import RandomUnderSampler
1414
from ..pipeline import make_pipeline
15-
from ..utils import Substitution
15+
from ..utils import Substitution, check_target_type
1616
from ..utils._docstring import _random_state_docstring
1717

1818

@@ -146,6 +146,7 @@ def fit(self, X, y, sample_weight=None):
146146
Returns self.
147147
148148
"""
149+
check_target_type(y)
149150
self.samplers_ = []
150151
self.pipelines_ = []
151152
super().fit(X, y, sample_weight)

imblearn/ensemble/tests/test_weight_boosting.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ def imbalanced_dataset():
2323
[({"n_estimators": 'whatever'}, "n_estimators must be an integer"),
2424
({"n_estimators": -100}, "n_estimators must be greater than zero")]
2525
)
26-
def test_balanced_random_forest_error(imbalanced_dataset, boosting_params,
27-
err_msg):
26+
def test_rusboost_error(imbalanced_dataset, boosting_params, err_msg):
2827
rusboost = RUSBoostClassifier(**boosting_params)
2928
with pytest.raises(ValueError, match=err_msg):
3029
rusboost.fit(*imbalanced_dataset)

imblearn/over_sampling/tests/test_smote_nc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ def test_smotenc_check_target_type():
131131
smote.fit_resample(X, y)
132132
rng = np.random.RandomState(42)
133133
y = rng.randint(2, size=(20, 3))
134-
with pytest.raises(ValueError, match="'y' should encode the multiclass"):
134+
msg = "Multilabel and multioutput targets are not supported."
135+
with pytest.raises(ValueError, match=msg):
135136
smote.fit_resample(X, y)
136137

137138

imblearn/utils/_validation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ def check_target_type(y, indicate_one_vs_all=False):
8888
if type_y == 'multilabel-indicator':
8989
if np.any(y.sum(axis=1) > 1):
9090
raise ValueError(
91-
"When 'y' corresponds to '{}', 'y' should encode the "
92-
"multiclass (a single 1 by row).".format(type_y))
91+
"Imbalanced-learn currently supports binary, multiclass and "
92+
"binarized encoded multiclasss targets. Multilabel and "
93+
"multioutput targets are not supported.")
9394
y = y.argmax(axis=1)
9495

9596
return (y, type_y == 'multilabel-indicator') if indicate_one_vs_all else y

imblearn/utils/estimator_checks.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from scipy import sparse
1818

1919
from sklearn.base import clone
20-
from sklearn.datasets import make_classification
20+
from sklearn.datasets import make_classification, make_multilabel_classification # noqa
2121
from sklearn.cluster import KMeans
2222
from sklearn.preprocessing import label_binarize
2323
from sklearn.utils.estimator_checks import check_estimator \
@@ -27,6 +27,7 @@
2727
from sklearn.utils.testing import set_random_state
2828
from sklearn.utils.multiclass import type_of_target
2929

30+
from imblearn.base import BaseSampler
3031
from imblearn.over_sampling.base import BaseOverSampler
3132
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
3233
from imblearn.ensemble.base import BaseEnsembleSampler
@@ -54,10 +55,18 @@ def _yield_sampler_checks(name, Estimator):
5455
yield check_samplers_sample_indices
5556

5657

58+
def _yield_classifier_checks(name, Estimator):
59+
yield check_classifier_on_multilabel_or_multioutput_targets
60+
61+
5762
def _yield_all_checks(name, estimator):
5863
# trigger our checks if this is a SamplerMixin
5964
if hasattr(estimator, 'fit_resample'):
60-
yield from _yield_sampler_checks(name, estimator)
65+
for check in _yield_sampler_checks(name, estimator):
66+
yield check
67+
if hasattr(estimator, 'predict'):
68+
for check in _yield_classifier_checks(name, estimator):
69+
yield check
6170

6271

6372
def check_estimator(Estimator, run_sampler_tests=True):
@@ -99,7 +108,8 @@ def check_target_type(name, Estimator):
99108
# if the target is multilabel then we should raise an error
100109
rng = np.random.RandomState(42)
101110
y = rng.randint(2, size=(20, 3))
102-
with pytest.raises(ValueError, match="'y' should encode the multiclass"):
111+
msg = "Multilabel and multioutput targets are not supported."
112+
with pytest.raises(ValueError, match=msg):
103113
estimator.fit_resample(X, y)
104114

105115

@@ -342,3 +352,11 @@ def check_samplers_sample_indices(name, Sampler):
342352
assert hasattr(sampler, 'sample_indices_') is sample_indices
343353
else:
344354
assert not hasattr(sampler, 'sample_indices_')
355+
356+
357+
def check_classifier_on_multilabel_or_multioutput_targets(name, Estimator):
358+
estimator = Estimator()
359+
X, y = make_multilabel_classification(n_samples=30)
360+
msg = "Multilabel and multioutput targets are not supported."
361+
with pytest.raises(ValueError, match=msg):
362+
estimator.fit(X, y)

0 commit comments

Comments
 (0)