From dae3ba387dc8b04373a59a6326e5fbd24f983250 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Mar 2018 02:03:46 +0100 Subject: [PATCH 1/9] EHN accept one-vs-all targets --- imblearn/__init__.py | 3 +++ imblearn/base.py | 19 ++++++++++++++++--- imblearn/utils/validation.py | 28 +++++++++++++++++++++------- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/imblearn/__init__.py b/imblearn/__init__.py index 9f05adb1f..7803ca016 100644 --- a/imblearn/__init__.py +++ b/imblearn/__init__.py @@ -13,6 +13,9 @@ exceptions Module including custom warnings and error clases used across imbalanced-learn. +keras + Module which provides custom generator, layers for deep learning using + keras. metrics Module which provides metrics to quantified the classification performance with imbalanced dataset. diff --git a/imblearn/base.py b/imblearn/base.py index aa12eb365..d352d9696 100644 --- a/imblearn/base.py +++ b/imblearn/base.py @@ -9,9 +9,13 @@ import logging from abc import ABCMeta, abstractmethod +import numpy as np + from sklearn.base import BaseEstimator from sklearn.externals import six +from sklearn.preprocessing import label_binarize from sklearn.utils import check_X_y +from sklearn.utils.multiclass import type_of_target from sklearn.utils.validation import check_is_fitted from .utils import check_ratio, check_target_type, hash_X_y @@ -54,14 +58,23 @@ def sample(self, X, y): The corresponding label of `X_resampled` """ - # Check the consistency of X and y + y, binarize_y = check_target_type(y, indicate_one_vs_all=True) X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']) check_is_fitted(self, 'ratio_') self._check_X_y(X, y) - return self._sample(X, y) + output = self._sample(X, y) + + if binarize_y: + y_sampled = label_binarize(output[1], np.unique(y)) + if len(output) == 2: + return output[0], y_sampled + else: + return output[0], y_sampled, output[2] + else: + return output def fit_sample(self, X, y): """Fit the statistics and resample the data 
directly. @@ -152,8 +165,8 @@ def fit(self, X, y): Return self. """ - X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']) y = check_target_type(y) + X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']) self.X_hash_, self.y_hash_ = hash_X_y(X, y) # self.sampling_type is already checked in check_ratio self.ratio_ = check_ratio(self.ratio, y, self._sampling_type) diff --git a/imblearn/utils/validation.py b/imblearn/utils/validation.py index d009dacbe..17fa7e93a 100644 --- a/imblearn/utils/validation.py +++ b/imblearn/utils/validation.py @@ -10,6 +10,7 @@ import numpy as np +from sklearn.preprocessing import label_binarize from sklearn.neighbors.base import KNeighborsMixin from sklearn.neighbors import NearestNeighbors from sklearn.externals import six, joblib @@ -19,7 +20,7 @@ SAMPLING_KIND = ('over-sampling', 'under-sampling', 'clean-sampling', 'ensemble') -TARGET_KIND = ('binary', 'multiclass') +TARGET_KIND = ('binary', 'multiclass', 'multilabel-indicator') def check_neighbors_object(nn_name, nn_object, additional_neighbor=0): @@ -54,29 +55,42 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0): raise_isinstance_error(nn_name, [int, KNeighborsMixin], nn_object) -def check_target_type(y): +def check_target_type(y, indicate_one_vs_all=False): """Check the target types to be conform to the current samplers. - The current samplers should be compatible with ``'binary'`` and - ``'multiclass'`` targets only. + The current samplers should be compatible with ``'binary'``, + ``'multilabel-indicator'`` and ``'multiclass'`` targets only. Parameters ---------- y : ndarray, - The array containing the target + The array containing the target. + + indicate_one_vs_all : bool, optional + Whether to indicate if the targets are encoded in a one-vs-all fashion. Returns ------- y : ndarray, The returned target. + is_one_vs_all : bool, optional + Indicate if the target was originally encoded in a one-vs-all fashion. + Only returned if ``indicate_one_vs_all=True``.
+ """ - if type_of_target(y) not in TARGET_KIND: + type_y = type_of_target(y) + if type_y not in TARGET_KIND: # FIXME: perfectly we should raise an error but the sklearn API does # not allow for it warnings.warn("'y' should be of types {} only. Got {} instead.".format( TARGET_KIND, type_of_target(y))) - return y + + if indicate_one_vs_all: + return (y.argmax(axis=1) if type_y == 'multilabel-indicator' else y, + type_y == 'multilabel-indicator') + else: + return y.argmax(axis=1) if type_y == 'multilabel-indicator' else y def hash_X_y(X, y, n_samples=10, n_features=5): From 7487ce48d7037823af4253064e8a7460b9594c50 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Mar 2018 02:29:49 +0100 Subject: [PATCH 2/9] TST add test for check_target_type --- imblearn/utils/tests/test_validation.py | 32 +++++++++++++++++++++++++ imblearn/utils/validation.py | 1 - 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/imblearn/utils/tests/test_validation.py b/imblearn/utils/tests/test_validation.py index a1d8585df..49e9d6997 100644 --- a/imblearn/utils/tests/test_validation.py +++ b/imblearn/utils/tests/test_validation.py @@ -6,17 +6,20 @@ from collections import Counter import numpy as np +import pytest from pytest import raises from sklearn.neighbors.base import KNeighborsMixin from sklearn.neighbors import NearestNeighbors from sklearn.utils import check_random_state from sklearn.externals import joblib +from sklearn.utils.testing import assert_array_equal from imblearn.utils.testing import warns from imblearn.utils import check_neighbors_object from imblearn.utils import check_ratio from imblearn.utils import hash_X_y +from imblearn.utils import check_target_type def test_check_neighbors_object(): @@ -35,6 +38,35 @@ def test_check_neighbors_object(): check_neighbors_object(name, n_neighbors) +@pytest.mark.parametrize( + "target, output_target", + [(np.array([0, 1, 1]), np.array([0, 1, 1])), + (np.array([0, 1, 2]), np.array([0, 1, 2])), + (np.array([[0, 1], 
[1, 0]]), np.array([1, 0]))] +) +def test_check_target_type(target, output_target): + converted_target = check_target_type(target.astype(int)) + assert_array_equal(converted_target, output_target.astype(int)) + + +@pytest.mark.parametrize( + "target, output_target, is_ova", + [(np.array([0, 1, 1]), np.array([0, 1, 1]), False), + (np.array([0, 1, 2]), np.array([0, 1, 2]), False), + (np.array([[0, 1], [1, 0]]), np.array([1, 0]), True)] +) +def test_check_target_type_ova(target, output_target, is_ova): + converted_target, binarize_target = check_target_type( + target.astype(int), indicate_one_vs_all=True) + assert_array_equal(converted_target, output_target.astype(int)) + assert binarize_target == is_ova + + +def test_check_target_warning(): + target = np.arange(4).reshape((2, 2)) + with pytest.warns(UserWarning, message='should be of types'): + check_target_type(target) + def test_check_ratio_error(): with raises(ValueError, match="'sampling_type' should be one of"): check_ratio('auto', np.array([1, 2, 3]), 'rnd') diff --git a/imblearn/utils/validation.py b/imblearn/utils/validation.py index 17fa7e93a..58488463a 100644 --- a/imblearn/utils/validation.py +++ b/imblearn/utils/validation.py @@ -10,7 +10,6 @@ import numpy as np -from sklearn.preprocessing import label_binarize from sklearn.neighbors.base import KNeighborsMixin from sklearn.neighbors import NearestNeighbors from sklearn.externals import six, joblib From 05ae2e6ff833f033086b31b4548545f0f31c4157 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Mar 2018 02:30:45 +0100 Subject: [PATCH 3/9] PEP8 --- imblearn/utils/tests/test_validation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/imblearn/utils/tests/test_validation.py b/imblearn/utils/tests/test_validation.py index 49e9d6997..64a23eb4a 100644 --- a/imblearn/utils/tests/test_validation.py +++ b/imblearn/utils/tests/test_validation.py @@ -67,6 +67,7 @@ def test_check_target_warning(): with pytest.warns(UserWarning, message='should be of 
types'): check_target_type(target) + def test_check_ratio_error(): with raises(ValueError, match="'sampling_type' should be one of"): check_ratio('auto', np.array([1, 2, 3]), 'rnd') From 05b7d65798e1baad01316f3e1b9f1913f7bf7cad Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Mar 2018 02:33:55 +0100 Subject: [PATCH 4/9] TST fix pytests match warns --- imblearn/utils/tests/test_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/imblearn/utils/tests/test_validation.py b/imblearn/utils/tests/test_validation.py index 64a23eb4a..bed62617d 100644 --- a/imblearn/utils/tests/test_validation.py +++ b/imblearn/utils/tests/test_validation.py @@ -64,7 +64,7 @@ def test_check_target_type_ova(target, output_target, is_ova): def test_check_target_warning(): target = np.arange(4).reshape((2, 2)) - with pytest.warns(UserWarning, message='should be of types'): + with pytest.warns(UserWarning, match='should be of types'): check_target_type(target) From 1a27e3e80b6020fe63dfc922821a86adac36a5c7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Mar 2018 03:26:24 +0100 Subject: [PATCH 5/9] TST common test to check multiclass ova equality --- imblearn/base.py | 29 +++++++++++++++------------- imblearn/combine/smote_enn.py | 2 +- imblearn/combine/smote_tomek.py | 3 +-- imblearn/ensemble/balance_cascade.py | 3 ++- imblearn/utils/estimator_checks.py | 18 +++++++++++++++++ 5 files changed, 38 insertions(+), 17 deletions(-) diff --git a/imblearn/base.py b/imblearn/base.py index d352d9696..a44831c0b 100644 --- a/imblearn/base.py +++ b/imblearn/base.py @@ -15,7 +15,6 @@ from sklearn.externals import six from sklearn.preprocessing import label_binarize from sklearn.utils import check_X_y -from sklearn.utils.multiclass import type_of_target from sklearn.utils.validation import check_is_fitted from .utils import check_ratio, check_target_type, hash_X_y @@ -245,17 +244,10 @@ def __init__(self, func=None, accept_sparse=True, kw_args=None): 
self.kw_args = kw_args self.logger = logging.getLogger(__name__) - def _check_X_y(self, X, y): - if self.accept_sparse: - X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']) - else: - X, y = check_X_y(X, y, accept_sparse=False) - y = check_target_type(y) - - return X, y - def fit(self, X, y): - X, y = self._check_X_y(X, y) + y = check_target_type(y) + X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'] + if self.accept_sparse else False) self.X_hash_, self.y_hash_ = hash_X_y(X, y) # when using a sampler, ratio_ is supposed to exist after fit self.ratio_ = 'is_fitted' @@ -263,7 +255,9 @@ def fit(self, X, y): return self def _sample(self, X, y, func=None, kw_args=None): - X, y = self._check_X_y(X, y) + y, binarize_y = check_target_type(y, indicate_one_vs_all=True) + X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'] + if self.accept_sparse else False) check_is_fitted(self, 'ratio_') X_hash, y_hash = hash_X_y(X, y) if self.X_hash_ != X_hash or self.y_hash_ != y_hash: @@ -272,7 +266,16 @@ def _sample(self, X, y, func=None, kw_args=None): if func is None: func = _identity - return func(X, y, **(kw_args if self.kw_args else {})) + output = func(X, y, **(kw_args if self.kw_args else {})) + + if binarize_y: + y_sampled = label_binarize(output[1], np.unique(y)) + if len(output) == 2: + return output[0], y_sampled + else: + return output[0], y_sampled, output[2] + else: + return output def sample(self, X, y): return self._sample(X, y, func=self.func, kw_args=self.kw_args) diff --git a/imblearn/combine/smote_enn.py b/imblearn/combine/smote_enn.py index 74420472b..470919878 100644 --- a/imblearn/combine/smote_enn.py +++ b/imblearn/combine/smote_enn.py @@ -144,8 +144,8 @@ def fit(self, X, y): Return self. 
""" - X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']) y = check_target_type(y) + X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']) self.ratio_ = self.ratio self.X_hash_, self.y_hash_ = hash_X_y(X, y) diff --git a/imblearn/combine/smote_tomek.py b/imblearn/combine/smote_tomek.py index b48e6510a..0748e6ef7 100644 --- a/imblearn/combine/smote_tomek.py +++ b/imblearn/combine/smote_tomek.py @@ -8,7 +8,6 @@ from __future__ import division import logging -import warnings from sklearn.utils import check_X_y @@ -153,8 +152,8 @@ def fit(self, X, y): Return self. """ - X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']) y = check_target_type(y) + X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']) self.ratio_ = self.ratio self.X_hash_, self.y_hash_ = hash_X_y(X, y) diff --git a/imblearn/ensemble/balance_cascade.py b/imblearn/ensemble/balance_cascade.py index bc6a06c6f..6668209ea 100644 --- a/imblearn/ensemble/balance_cascade.py +++ b/imblearn/ensemble/balance_cascade.py @@ -14,7 +14,7 @@ from sklearn.model_selection import cross_val_predict from .base import BaseEnsembleSampler -from ..utils import check_ratio +from ..utils import check_ratio, check_target_type class BalanceCascade(BaseEnsembleSampler): @@ -137,6 +137,7 @@ def fit(self, X, y): """ super(BalanceCascade, self).fit(X, y) + y = check_target_type(y) self.ratio_ = check_ratio(self.ratio, y, 'under-sampling') return self diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index 2184b4b12..bd37e408a 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -19,11 +19,13 @@ from sklearn.datasets import make_classification from sklearn.cluster import KMeans +from sklearn.preprocessing import label_binarize from sklearn.utils.estimator_checks import check_estimator \ as sklearn_check_estimator, check_parameters_default_constructible from sklearn.exceptions import NotFittedError from sklearn.utils.testing import assert_allclose from 
sklearn.utils.testing import set_random_state +from sklearn.utils.multiclass import type_of_target from imblearn.over_sampling.base import BaseOverSampler from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler @@ -44,6 +46,7 @@ def _yield_sampler_checks(name, Estimator): yield check_samplers_ratio_fit_sample yield check_samplers_sparse yield check_samplers_pandas + yield check_samplers_multiclass_ova def _yield_all_checks(name, estimator): @@ -253,3 +256,18 @@ def check_samplers_pandas(name, Sampler): X_res, y_res = sampler.fit_sample(X, y) assert_allclose(X_res_pd, X_res) assert_allclose(y_res_pd, y_res) + + +def check_samplers_multiclass_ova(name, Sampler): + # Check that multiclass target lead to the same results than OVA encoding + X, y = make_classification(n_samples=1000, n_classes=3, + n_informative=4, weights=[0.2, 0.3, 0.5], + random_state=0) + y_ova = label_binarize(y, np.unique(y)) + sampler = Sampler() + set_random_state(sampler) + X_res, y_res = sampler.fit_sample(X, y) + X_res_ova, y_res_ova = sampler.fit_sample(X, y_ova) + assert_allclose(X_res, X_res_ova) + assert type_of_target(y_res_ova) == type_of_target(y_ova) + assert_allclose(y_res, y_res_ova.argmax(axis=1)) From f5a583f7fdcbc65445ccf232a9f00d673e48250d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Mar 2018 15:37:25 +0100 Subject: [PATCH 6/9] FIX/TST ensemble handle one-vs-all encoding --- imblearn/ensemble/base.py | 50 ++++++++++++++++++++++++++++++ imblearn/utils/estimator_checks.py | 9 ++++-- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/imblearn/ensemble/base.py b/imblearn/ensemble/base.py index 87fbd8250..bb1d34ebc 100644 --- a/imblearn/ensemble/base.py +++ b/imblearn/ensemble/base.py @@ -4,7 +4,14 @@ # Authors: Guillaume Lemaitre # License: MIT +import numpy as np + +from sklearn.preprocessing import label_binarize +from sklearn.utils import check_X_y +from sklearn.utils.validation import check_is_fitted + from ..base import 
BaseSampler +from ..utils import check_target_type class BaseEnsembleSampler(BaseSampler): @@ -15,3 +22,46 @@ class BaseEnsembleSampler(BaseSampler): """ _sampling_type = 'ensemble' + + def sample(self, X, y): + """Resample the dataset. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Matrix containing the data which have to be sampled. + + y : array-like, shape (n_samples,) + Corresponding label for each sample in X. + + Returns + ------- + X_resampled : {ndarray, sparse matrix}, shape \ +(n_subset, n_samples_new, n_features) + The array containing the resampled data. + + y_resampled : ndarray, shape (n_subset, n_samples_new) + The corresponding label of `X_resampled` + + """ + # Ensemble are a bit specific since they are returning an array of + # resampled arrays. + y, binarize_y = check_target_type(y, indicate_one_vs_all=True) + X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']) + + check_is_fitted(self, 'ratio_') + self._check_X_y(X, y) + + output = self._sample(X, y) + + if binarize_y: + y_resampled = output[1] + classes = np.unique(y) + y_resampled_encoded = np.array([label_binarize(batch_y, classes) + for batch_y in y_resampled]) + if len(output) == 2: + return output[0], y_resampled_encoded + else: + return output[0], y_resampled_encoded, output[2] + else: + return output diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index bd37e408a..db4cc2bd9 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -269,5 +269,10 @@ def check_samplers_multiclass_ova(name, Sampler): X_res, y_res = sampler.fit_sample(X, y) X_res_ova, y_res_ova = sampler.fit_sample(X, y_ova) assert_allclose(X_res, X_res_ova) - assert type_of_target(y_res_ova) == type_of_target(y_ova) - assert_allclose(y_res, y_res_ova.argmax(axis=1)) + if issubclass(Sampler, BaseEnsembleSampler): + for batch_y, batch_y_ova in zip(y_res, y_res_ova): + assert type_of_target(batch_y_ova) == 
type_of_target(y_ova) + assert_allclose(batch_y, batch_y_ova.argmax(axis=1)) + else: + assert type_of_target(y_res_ova) == type_of_target(y_ova) + assert_allclose(y_res, y_res_ova.argmax(axis=1)) From 1435b3515e81376de7822b29e449405a98dadc32 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Mar 2018 18:46:55 +0100 Subject: [PATCH 7/9] DOC remove wrong documentation module --- imblearn/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/imblearn/__init__.py b/imblearn/__init__.py index 7803ca016..9f05adb1f 100644 --- a/imblearn/__init__.py +++ b/imblearn/__init__.py @@ -13,9 +13,6 @@ exceptions Module including custom warnings and error clases used across imbalanced-learn. -keras - Module which provides custom generator, layers for deep learning using - keras. metrics Module which provides metrics to quantified the classification performance with imbalanced dataset. From 6350718f32a0c5f07507276c51cdf3c17df1bddc Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Mar 2018 18:55:21 +0100 Subject: [PATCH 8/9] add whats new entry --- doc/whats_new/v0.0.4.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats_new/v0.0.4.rst b/doc/whats_new/v0.0.4.rst index 67b012ae1..66c360e45 100644 --- a/doc/whats_new/v0.0.4.rst +++ b/doc/whats_new/v0.0.4.rst @@ -15,6 +15,9 @@ Enhancement - Document the metrics to evaluate models on imbalanced dataset. :issue:`367` by :user:`Guillaume Lemaitre `. +- Add support for one-vs-all encoded target to support keras. :issue:`410` by + :user:`Guillaume Lemaitre `. + Bug fixes ......... 
From dae94b90f801bfa9be6ae6867c644cbcfc472c41 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Mar 2018 18:55:54 +0100 Subject: [PATCH 9/9] correct issue --- doc/whats_new/v0.0.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v0.0.4.rst b/doc/whats_new/v0.0.4.rst index 66c360e45..0e3211b31 100644 --- a/doc/whats_new/v0.0.4.rst +++ b/doc/whats_new/v0.0.4.rst @@ -15,7 +15,7 @@ Enhancement - Document the metrics to evaluate models on imbalanced dataset. :issue:`367` by :user:`Guillaume Lemaitre `. -- Add support for one-vs-all encoded target to support keras. :issue:`410` by +- Add support for one-vs-all encoded target to support keras. :issue:`409` by :user:`Guillaume Lemaitre `. Bug fixes