diff --git a/doc/introduction.rst b/doc/introduction.rst
index 6b8aa8cf3..1ea72326f 100644
--- a/doc/introduction.rst
+++ b/doc/introduction.rst
@@ -30,6 +30,12 @@ Imbalanced-learn samplers accept the same inputs that in scikit-learn:
   matrices;
 * ``targets``: array-like (1-D list, pandas.Series, numpy.array).
 
+The output will be of the following types:
+
+* ``data_resampled``: array-like (2-D list, pandas.DataFrame, numpy.array) or
+  sparse matrices;
+* ``targets_resampled``: 1-D numpy.array.
+
 .. topic:: Sparse input
 
   For sparse input the data is **converted to the Compressed Sparse Rows
diff --git a/doc/whats_new/v0.6.rst b/doc/whats_new/v0.6.rst
index e5aedaf72..2d99e44f8 100644
--- a/doc/whats_new/v0.6.rst
+++ b/doc/whats_new/v0.6.rst
@@ -54,7 +54,11 @@ Enhancement
 - :class:`imblearn.under_sampling.RandomUnderSampler`,
   :class:`imblearn.over_sampling.RandomOverSampler` can resample when non
   finite values are present in ``X``.
-  :pr:`643` by `Guillaume Lemaitre <glemaitre>`.
+  :pr:`643` by :user:`Guillaume Lemaitre <glemaitre>`.
+
+- All samplers will output a Pandas DataFrame if a Pandas DataFrame was given
+  as an input.
+  :pr:`644` by :user:`Guillaume Lemaitre <glemaitre>`.
 
 Deprecation
 ...........
diff --git a/imblearn/base.py b/imblearn/base.py
index 6f57d1240..b182f3e72 100644
--- a/imblearn/base.py
+++ b/imblearn/base.py
@@ -32,7 +32,8 @@ def fit(self, X, y):
 
         Parameters
         ----------
-        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+        X : {array-like, dataframe, sparse matrix} of shape \
+                (n_samples, n_features)
             Data array.
 
         y : array-like of shape (n_samples,)
@@ -54,7 +55,8 @@ def fit_resample(self, X, y):
 
         Parameters
         ----------
-        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+        X : {array-like, dataframe, sparse matrix} of shape \
+                (n_samples, n_features)
             Matrix containing the data which have to be sampled.
 
         y : array-like of shape (n_samples,)
@@ -62,7 +64,7 @@ def fit_resample(self, X, y):
 
         Returns
         -------
-        X_resampled : {array-like, sparse matrix} of shape \
+        X_resampled : {array-like, dataframe, sparse matrix} of shape \
                 (n_samples_new, n_features)
             The array containing the resampled data.
 
@@ -78,12 +80,20 @@ def fit_resample(self, X, y):
 
         output = self._fit_resample(X, y)
 
+        if self._columns is not None:
+            import pandas as pd
+            X_ = pd.DataFrame(output[0], columns=self._columns)
+        else:
+            X_ = output[0]
+
         if binarize_y:
             y_sampled = label_binarize(output[1], np.unique(y))
             if len(output) == 2:
-                return output[0], y_sampled
-            return output[0], y_sampled, output[2]
-        return output
+                return X_, y_sampled
+            return X_, y_sampled, output[2]
+        if len(output) == 2:
+            return X_, output[1]
+        return X_, output[1], output[2]
 
     # define an alias for back-compatibility
     fit_sample = fit_resample
@@ -124,8 +134,9 @@ class BaseSampler(SamplerMixin):
     def __init__(self, sampling_strategy="auto"):
         self.sampling_strategy = sampling_strategy
 
-    @staticmethod
-    def _check_X_y(X, y, accept_sparse=None):
+    def _check_X_y(self, X, y, accept_sparse=None):
+        # store the columns name to reconstruct a dataframe
+        self._columns = X.columns if hasattr(X, "loc") else None
         if accept_sparse is None:
             accept_sparse = ["csr", "csc"]
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
@@ -238,6 +249,8 @@ def fit_resample(self, X, y):
         y_resampled : array-like of shape (n_samples_new,)
             The corresponding label of `X_resampled`.
         """
+        # store the columns name to reconstruct a dataframe
+        self._columns = X.columns if hasattr(X, "loc") else None
         if self.validate:
             check_classification_targets(y)
             X, y, binarize_y = self._check_X_y(
@@ -250,12 +263,20 @@ def fit_resample(self, X, y):
 
         output = self._fit_resample(X, y)
 
+        if self._columns is not None:
+            import pandas as pd
+            X_ = pd.DataFrame(output[0], columns=self._columns)
+        else:
+            X_ = output[0]
+
         if self.validate and binarize_y:
             y_sampled = label_binarize(output[1], np.unique(y))
             if len(output) == 2:
-                return output[0], y_sampled
-            return output[0], y_sampled, output[2]
-        return output
+                return X_, y_sampled
+            return X_, y_sampled, output[2]
+        if len(output) == 2:
+            return X_, output[1]
+        return X_, output[1], output[2]
 
     def _fit_resample(self, X, y):
         func = _identity if self.func is None else self.func
diff --git a/imblearn/ensemble/base.py b/imblearn/ensemble/base.py
deleted file mode 100644
index 935bb3dfe..000000000
--- a/imblearn/ensemble/base.py
+++ /dev/null
@@ -1,64 +0,0 @@
-"""
-Base class for the ensemble method.
-"""
-# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
-# License: MIT
-
-import numpy as np
-
-from sklearn.preprocessing import label_binarize
-
-from ..base import BaseSampler
-from ..utils import check_sampling_strategy
-
-
-class BaseEnsembleSampler(BaseSampler):
-    """Base class for ensemble algorithms.
-
-    Warning: This class should not be used directly. Use the derive classes
-    instead.
-    """
-
-    _sampling_type = "ensemble"
-
-    def fit_resample(self, X, y):
-        """Resample the dataset.
-
-        Parameters
-        ----------
-        X : {array-like, sparse matrix}, shape (n_samples, n_features)
-            Matrix containing the data which have to be sampled.
-
-        y : array-like, shape (n_samples,)
-            Corresponding label for each sample in X.
-
-        Returns
-        -------
-        X_resampled : {ndarray, sparse matrix}, shape \
-(n_subset, n_samples_new, n_features)
-            The array containing the resampled data.
-
-        y_resampled : ndarray, shape (n_subset, n_samples_new)
-            The corresponding label of `X_resampled`
-
-        """
-        # Ensemble are a bit specific since they are returning an array of
-        # resampled arrays.
-        X, y, binarize_y = self._check_X_y(X, y)
-
-        self.sampling_strategy_ = check_sampling_strategy(
-            self.sampling_strategy, y, self._sampling_type
-        )
-
-        output = self._fit_resample(X, y)
-
-        if binarize_y:
-            y_resampled = output[1]
-            classes = np.unique(y)
-            y_resampled_encoded = np.array(
-                [label_binarize(batch_y, classes) for batch_y in y_resampled]
-            )
-            if len(output) == 2:
-                return output[0], y_resampled_encoded
-            return output[0], y_resampled_encoded, output[2]
-        return output
diff --git a/imblearn/over_sampling/_random_over_sampler.py b/imblearn/over_sampling/_random_over_sampler.py
index 953f16641..ea8b4d18b 100644
--- a/imblearn/over_sampling/_random_over_sampler.py
+++ b/imblearn/over_sampling/_random_over_sampler.py
@@ -74,13 +74,12 @@ def __init__(self, sampling_strategy="auto", random_state=None):
         super().__init__(sampling_strategy=sampling_strategy)
         self.random_state = random_state
 
-    @staticmethod
-    def _check_X_y(X, y):
+    def _check_X_y(self, X, y):
+        # store the columns name to reconstruct a dataframe
+        self._columns = X.columns if hasattr(X, "loc") else None
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
-        if not hasattr(X, "loc"):
-            # Do not convert dataframe
-            X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
-                            force_all_finite=False)
+        X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
+                        force_all_finite=False)
         y = check_array(
             y, accept_sparse=["csr", "csc"], dtype=None, ensure_2d=False
         )
diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py
index c583abb20..20e979cd3 100644
--- a/imblearn/over_sampling/_smote.py
+++ b/imblearn/over_sampling/_smote.py
@@ -900,11 +900,12 @@ def __init__(
         )
         self.categorical_features = categorical_features
 
-    @staticmethod
-    def _check_X_y(X, y):
+    def _check_X_y(self, X, y):
         """Overwrite the checking to let pass some string for categorical
         features.
         """
+        # store the columns name to reconstruct a dataframe
+        self._columns = X.columns if hasattr(X, "loc") else None
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
         X, y = check_X_y(X, y, accept_sparse=["csr", "csc"], dtype=None)
         return X, y, binarize_y
diff --git a/imblearn/over_sampling/tests/test_smote_nc.py b/imblearn/over_sampling/tests/test_smote_nc.py
index a495c775d..c7ba80caa 100644
--- a/imblearn/over_sampling/tests/test_smote_nc.py
+++ b/imblearn/over_sampling/tests/test_smote_nc.py
@@ -13,6 +13,7 @@
 
 from sklearn.datasets import make_classification
 from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import assert_array_equal
 
 from imblearn.over_sampling import SMOTENC
 
@@ -184,7 +185,7 @@ def test_smotenc_pandas():
     smote = SMOTENC(categorical_features=categorical_features, random_state=0)
     X_res_pd, y_res_pd = smote.fit_resample(X_pd, y)
     X_res, y_res = smote.fit_resample(X, y)
-    assert X_res_pd.tolist() == X_res.tolist()
+    assert_array_equal(X_res_pd.to_numpy(), X_res)
     assert_allclose(y_res_pd, y_res)
 
 
diff --git a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py
index 02f014f58..6301822ea 100644
--- a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py
+++ b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py
@@ -80,13 +80,12 @@ def __init__(
         self.random_state = random_state
         self.replacement = replacement
 
-    @staticmethod
-    def _check_X_y(X, y):
+    def _check_X_y(self, X, y):
+        # store the columns name to reconstruct a dataframe
+        self._columns = X.columns if hasattr(X, "loc") else None
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
-        if not hasattr(X, "loc"):
-            # Do not convert dataframe
-            X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
-                            force_all_finite=False)
+        X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
+                        force_all_finite=False)
         y = check_array(
             y, accept_sparse=["csr", "csc"], dtype=None, ensure_2d=False
         )
diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py
index ce885cf56..4fef2a13b 100644
--- a/imblearn/utils/estimator_checks.py
+++ b/imblearn/utils/estimator_checks.py
@@ -30,7 +30,6 @@
 
 from imblearn.over_sampling.base import BaseOverSampler
 from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
-from imblearn.ensemble.base import BaseEnsembleSampler
 from imblearn.under_sampling import NearMiss, ClusterCentroids
 
 
@@ -168,12 +167,6 @@ def check_samplers_fit_resample(name, Sampler):
             for class_sample in target_stats.keys()
             if class_sample != class_minority
         )
-    elif isinstance(sampler, BaseEnsembleSampler):
-        y_ensemble = y_res[0]
-        n_samples = min(target_stats.values())
-        assert all(
-            value == n_samples for value in Counter(y_ensemble).values()
-        )
 
 
 def check_samplers_sampling_strategy_fit_resample(name, Sampler):
@@ -202,12 +195,6 @@ def check_samplers_sampling_strategy_fit_resample(name, Sampler):
         sampler.set_params(sampling_strategy=sampling_strategy)
         X_res, y_res = sampler.fit_resample(X, y)
         assert Counter(y_res)[1] == expected_stat
-    if isinstance(sampler, BaseEnsembleSampler):
-        sampling_strategy = {2: 201, 0: 201}
-        sampler.set_params(sampling_strategy=sampling_strategy)
-        X_res, y_res = sampler.fit_resample(X, y)
-        y_ensemble = y_res[0]
-        assert Counter(y_ensemble)[1] == expected_stat
 
 
 def check_samplers_sparse(name, Sampler):
@@ -239,17 +226,9 @@ def check_samplers_sparse(name, Sampler):
         set_random_state(sampler)
         X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
         X_res, y_res = sampler.fit_resample(X, y)
-        if not isinstance(sampler, BaseEnsembleSampler):
-            assert sparse.issparse(X_res_sparse)
-            assert_allclose(X_res_sparse.A, X_res)
-            assert_allclose(y_res_sparse, y_res)
-        else:
-            for x_sp, x, y_sp, y in zip(
-                X_res_sparse, X_res, y_res_sparse, y_res
-            ):
-                assert sparse.issparse(x_sp)
-                assert_allclose(x_sp.A, x)
-                assert_allclose(y_sp, y)
+        assert sparse.issparse(X_res_sparse)
+        assert_allclose(X_res_sparse.A, X_res)
+        assert_allclose(y_res_sparse, y_res)
 
 
 def check_samplers_pandas(name, Sampler):
@@ -262,7 +241,7 @@ def check_samplers_pandas(name, Sampler):
         weights=[0.2, 0.3, 0.5],
         random_state=0,
     )
-    X_pd = pd.DataFrame(X)
+    X_pd = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
     sampler = Sampler()
    if isinstance(Sampler(), NearMiss):
        samplers = [Sampler(version=version) for version in (1, 2, 3)]
@@ -274,7 +253,11 @@ def check_samplers_pandas(name, Sampler):
         set_random_state(sampler)
         X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y)
         X_res, y_res = sampler.fit_resample(X, y)
-        assert_allclose(X_res_pd, X_res)
+
+        # check that we return a pandas dataframe if a dataframe was given in
+        assert isinstance(X_res_pd, pd.DataFrame)
+        assert X_pd.columns.to_list() == X_res_pd.columns.to_list()
+        assert_allclose(X_res_pd.to_numpy(), X_res)
         assert_allclose(y_res_pd, y_res)
 
 
@@ -293,13 +276,8 @@ def check_samplers_multiclass_ova(name, Sampler):
     X_res, y_res = sampler.fit_resample(X, y)
     X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)
     assert_allclose(X_res, X_res_ova)
-    if issubclass(Sampler, BaseEnsembleSampler):
-        for batch_y, batch_y_ova in zip(y_res, y_res_ova):
-            assert type_of_target(batch_y_ova) == type_of_target(y_ova)
-            assert_allclose(batch_y, batch_y_ova.argmax(axis=1))
-    else:
-        assert type_of_target(y_res_ova) == type_of_target(y_ova)
-        assert_allclose(y_res, y_res_ova.argmax(axis=1))
+    assert type_of_target(y_res_ova) == type_of_target(y_ova)
+    assert_allclose(y_res, y_res_ova.argmax(axis=1))
 
 
 def check_samplers_preserve_dtype(name, Sampler):
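
Below is a minimal usage sketch of the behaviour introduced by this patch (DataFrame in, DataFrame out), mirroring what check_samplers_pandas asserts. The sampler choice, column names, and class counts are illustrative only and are not part of the diff:

    import numpy as np
    import pandas as pd

    from imblearn.under_sampling import RandomUnderSampler

    # small imbalanced toy dataset with named columns: 8 samples of class 0, 2 of class 1
    X = pd.DataFrame({"f0": np.arange(10, dtype=float),
                      "f1": np.arange(10, dtype=float) * 2})
    y = np.array([0] * 8 + [1] * 2)

    # with this patch, passing a DataFrame returns a DataFrame with the same columns
    X_res, y_res = RandomUnderSampler(random_state=0).fit_resample(X, y)
    assert isinstance(X_res, pd.DataFrame)
    assert list(X_res.columns) == ["f0", "f1"]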