Skip to content

[MRG] MAINT explicit fail messages on non supported targets #544

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jun 11, 2019
3 changes: 2 additions & 1 deletion imblearn/ensemble/_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..pipeline import Pipeline
from ..under_sampling import RandomUnderSampler
from ..under_sampling.base import BaseUnderSampler
from ..utils import Substitution
from ..utils import Substitution, check_target_type
from ..utils._docstring import _random_state_docstring


Expand Down Expand Up @@ -240,6 +240,7 @@ def fit(self, X, y):
self : object
Returns self.
"""
check_target_type(y)
# RandomUnderSampler is not supporting sample_weight. We need to pass
# None.
return self._fit(X, y, self.max_samples, sample_weight=None)
3 changes: 2 additions & 1 deletion imblearn/ensemble/_easy_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .base import BaseEnsembleSampler
from ..under_sampling import RandomUnderSampler
from ..under_sampling.base import BaseUnderSampler
from ..utils import Substitution
from ..utils import Substitution, check_target_type
from ..utils._docstring import _random_state_docstring
from ..pipeline import Pipeline

Expand Down Expand Up @@ -290,6 +290,7 @@ def fit(self, X, y):
self : object
Returns self.
"""
check_target_type(y)
# RandomUnderSampler is not supporting sample_weight. We need to pass
# None.
return self._fit(X, y, self.max_samples, sample_weight=None)
3 changes: 2 additions & 1 deletion imblearn/ensemble/_weight_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..under_sampling.base import BaseUnderSampler
from ..under_sampling import RandomUnderSampler
from ..pipeline import make_pipeline
from ..utils import Substitution
from ..utils import Substitution, check_target_type
from ..utils._docstring import _random_state_docstring


Expand Down Expand Up @@ -146,6 +146,7 @@ def fit(self, X, y, sample_weight=None):
Returns self.

"""
check_target_type(y)
self.samplers_ = []
self.pipelines_ = []
super().fit(X, y, sample_weight)
Expand Down
3 changes: 1 addition & 2 deletions imblearn/ensemble/tests/test_weight_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ def imbalanced_dataset():
[({"n_estimators": 'whatever'}, "n_estimators must be an integer"),
({"n_estimators": -100}, "n_estimators must be greater than zero")]
)
def test_balanced_random_forest_error(imbalanced_dataset, boosting_params,
err_msg):
def test_rusboost_error(imbalanced_dataset, boosting_params, err_msg):
rusboost = RUSBoostClassifier(**boosting_params)
with pytest.raises(ValueError, match=err_msg):
rusboost.fit(*imbalanced_dataset)
Expand Down
3 changes: 2 additions & 1 deletion imblearn/over_sampling/tests/test_smote_nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def test_smotenc_check_target_type():
smote.fit_resample(X, y)
rng = np.random.RandomState(42)
y = rng.randint(2, size=(20, 3))
with pytest.raises(ValueError, match="'y' should encode the multiclass"):
msg = "Multilabel and multioutput targets are not supported."
with pytest.raises(ValueError, match=msg):
smote.fit_resample(X, y)


Expand Down
5 changes: 3 additions & 2 deletions imblearn/utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ def check_target_type(y, indicate_one_vs_all=False):
if type_y == 'multilabel-indicator':
if np.any(y.sum(axis=1) > 1):
raise ValueError(
"When 'y' corresponds to '{}', 'y' should encode the "
"multiclass (a single 1 by row).".format(type_y))
"Imbalanced-learn currently supports binary, multiclass and "
"binarized encoded multiclasss targets. Multilabel and "
"multioutput targets are not supported.")
y = y.argmax(axis=1)

return (y, type_y == 'multilabel-indicator') if indicate_one_vs_all else y
Expand Down
24 changes: 21 additions & 3 deletions imblearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from scipy import sparse

from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.datasets import make_classification, make_multilabel_classification # noqa
from sklearn.cluster import KMeans
from sklearn.preprocessing import label_binarize
from sklearn.utils.estimator_checks import check_estimator \
Expand All @@ -27,6 +27,7 @@
from sklearn.utils.testing import set_random_state
from sklearn.utils.multiclass import type_of_target

from imblearn.base import BaseSampler
from imblearn.over_sampling.base import BaseOverSampler
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
from imblearn.ensemble.base import BaseEnsembleSampler
Expand Down Expand Up @@ -54,10 +55,18 @@ def _yield_sampler_checks(name, Estimator):
yield check_samplers_sample_indices


def _yield_classifier_checks(name, Estimator):
yield check_classifier_on_multilabel_or_multioutput_targets


def _yield_all_checks(name, estimator):
# trigger our checks if this is a SamplerMixin
if hasattr(estimator, 'fit_resample'):
yield from _yield_sampler_checks(name, estimator)
for check in _yield_sampler_checks(name, estimator):
yield check
if hasattr(estimator, 'predict'):
for check in _yield_classifier_checks(name, estimator):
yield check


def check_estimator(Estimator, run_sampler_tests=True):
Expand Down Expand Up @@ -99,7 +108,8 @@ def check_target_type(name, Estimator):
# if the target is multilabel then we should raise an error
rng = np.random.RandomState(42)
y = rng.randint(2, size=(20, 3))
with pytest.raises(ValueError, match="'y' should encode the multiclass"):
msg = "Multilabel and multioutput targets are not supported."
with pytest.raises(ValueError, match=msg):
estimator.fit_resample(X, y)


Expand Down Expand Up @@ -342,3 +352,11 @@ def check_samplers_sample_indices(name, Sampler):
assert hasattr(sampler, 'sample_indices_') is sample_indices
else:
assert not hasattr(sampler, 'sample_indices_')


def check_classifier_on_multilabel_or_multioutput_targets(name, Estimator):
estimator = Estimator()
X, y = make_multilabel_classification(n_samples=30)
msg = "Multilabel and multioutput targets are not supported."
with pytest.raises(ValueError, match=msg):
estimator.fit(X, y)