diff --git a/imblearn/ensemble/_bagging.py b/imblearn/ensemble/_bagging.py index 5fbb2f7b4..89e764330 100644 --- a/imblearn/ensemble/_bagging.py +++ b/imblearn/ensemble/_bagging.py @@ -15,7 +15,7 @@ from ..pipeline import Pipeline from ..under_sampling import RandomUnderSampler from ..under_sampling.base import BaseUnderSampler -from ..utils import Substitution +from ..utils import Substitution, check_target_type from ..utils._docstring import _random_state_docstring @@ -240,6 +240,7 @@ def fit(self, X, y): self : object Returns self. """ + check_target_type(y) # RandomUnderSampler is not supporting sample_weight. We need to pass # None. return self._fit(X, y, self.max_samples, sample_weight=None) diff --git a/imblearn/ensemble/_easy_ensemble.py b/imblearn/ensemble/_easy_ensemble.py index f846f18a9..47ffa6338 100644 --- a/imblearn/ensemble/_easy_ensemble.py +++ b/imblearn/ensemble/_easy_ensemble.py @@ -17,7 +17,7 @@ from .base import BaseEnsembleSampler from ..under_sampling import RandomUnderSampler from ..under_sampling.base import BaseUnderSampler -from ..utils import Substitution +from ..utils import Substitution, check_target_type from ..utils._docstring import _random_state_docstring from ..pipeline import Pipeline @@ -290,6 +290,7 @@ def fit(self, X, y): self : object Returns self. """ + check_target_type(y) # RandomUnderSampler is not supporting sample_weight. We need to pass # None. return self._fit(X, y, self.max_samples, sample_weight=None) diff --git a/imblearn/ensemble/_weight_boosting.py b/imblearn/ensemble/_weight_boosting.py index 6c2fd62c8..b9a90118e 100644 --- a/imblearn/ensemble/_weight_boosting.py +++ b/imblearn/ensemble/_weight_boosting.py @@ -12,7 +12,7 @@ from ..under_sampling.base import BaseUnderSampler from ..under_sampling import RandomUnderSampler from ..pipeline import make_pipeline -from ..utils import Substitution +from ..utils import Substitution, check_target_type from ..utils._docstring import _random_state_docstring @@ -146,6 +146,7 @@ def fit(self, X, y, sample_weight=None): Returns self. """ + check_target_type(y) self.samplers_ = [] self.pipelines_ = [] super().fit(X, y, sample_weight) diff --git a/imblearn/ensemble/tests/test_weight_boosting.py b/imblearn/ensemble/tests/test_weight_boosting.py index 8cc38776e..2027da983 100644 --- a/imblearn/ensemble/tests/test_weight_boosting.py +++ b/imblearn/ensemble/tests/test_weight_boosting.py @@ -23,8 +23,7 @@ def imbalanced_dataset(): [({"n_estimators": 'whatever'}, "n_estimators must be an integer"), ({"n_estimators": -100}, "n_estimators must be greater than zero")] ) -def test_balanced_random_forest_error(imbalanced_dataset, boosting_params, - err_msg): +def test_rusboost_error(imbalanced_dataset, boosting_params, err_msg): rusboost = RUSBoostClassifier(**boosting_params) with pytest.raises(ValueError, match=err_msg): rusboost.fit(*imbalanced_dataset) diff --git a/imblearn/over_sampling/tests/test_smote_nc.py b/imblearn/over_sampling/tests/test_smote_nc.py index 36a9ac788..b9533bc76 100644 --- a/imblearn/over_sampling/tests/test_smote_nc.py +++ b/imblearn/over_sampling/tests/test_smote_nc.py @@ -131,7 +131,8 @@ def test_smotenc_check_target_type(): smote.fit_resample(X, y) rng = np.random.RandomState(42) y = rng.randint(2, size=(20, 3)) - with pytest.raises(ValueError, match="'y' should encode the multiclass"): + msg = "Multilabel and multioutput targets are not supported." + with pytest.raises(ValueError, match=msg): smote.fit_resample(X, y) diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py index 5db8a0a3f..abad3d554 100644 --- a/imblearn/utils/_validation.py +++ b/imblearn/utils/_validation.py @@ -88,8 +88,9 @@ def check_target_type(y, indicate_one_vs_all=False): if type_y == 'multilabel-indicator': if np.any(y.sum(axis=1) > 1): raise ValueError( - "When 'y' corresponds to '{}', 'y' should encode the " - "multiclass (a single 1 by row).".format(type_y)) + "Imbalanced-learn currently supports binary, multiclass and " + "binarized encoded multiclasss targets. Multilabel and " + "multioutput targets are not supported.") y = y.argmax(axis=1) return (y, type_y == 'multilabel-indicator') if indicate_one_vs_all else y diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index 72960c9c8..bb0734c9e 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -17,7 +17,7 @@ from scipy import sparse from sklearn.base import clone -from sklearn.datasets import make_classification +from sklearn.datasets import make_classification, make_multilabel_classification # noqa from sklearn.cluster import KMeans from sklearn.preprocessing import label_binarize from sklearn.utils.estimator_checks import check_estimator \ @@ -27,6 +27,7 @@ from sklearn.utils.testing import set_random_state from sklearn.utils.multiclass import type_of_target +from imblearn.base import BaseSampler from imblearn.over_sampling.base import BaseOverSampler from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler from imblearn.ensemble.base import BaseEnsembleSampler @@ -54,10 +55,18 @@ def _yield_sampler_checks(name, Estimator): yield check_samplers_sample_indices +def _yield_classifier_checks(name, Estimator): + yield check_classifier_on_multilabel_or_multioutput_targets + + def _yield_all_checks(name, estimator): # trigger our checks if this is a SamplerMixin if hasattr(estimator, 'fit_resample'): - yield from _yield_sampler_checks(name, estimator) + for check in _yield_sampler_checks(name, estimator): + yield check + if hasattr(estimator, 'predict'): + for check in _yield_classifier_checks(name, estimator): + yield check def check_estimator(Estimator, run_sampler_tests=True): @@ -99,7 +108,8 @@ def check_target_type(name, Estimator): # if the target is multilabel then we should raise an error rng = np.random.RandomState(42) y = rng.randint(2, size=(20, 3)) - with pytest.raises(ValueError, match="'y' should encode the multiclass"): + msg = "Multilabel and multioutput targets are not supported." + with pytest.raises(ValueError, match=msg): estimator.fit_resample(X, y) @@ -342,3 +352,11 @@ def check_samplers_sample_indices(name, Sampler): assert hasattr(sampler, 'sample_indices_') is sample_indices else: assert not hasattr(sampler, 'sample_indices_') + + +def check_classifier_on_multilabel_or_multioutput_targets(name, Estimator): + estimator = Estimator() + X, y = make_multilabel_classification(n_samples=30) + msg = "Multilabel and multioutput targets are not supported." + with pytest.raises(ValueError, match=msg): + estimator.fit(X, y)