Skip to content

Commit 158258e

Browse files
authored
ENH dataframe in/out for all samplers (#644)
1 parent b606cb9 commit 158258e

File tree

9 files changed

+69
-124
lines changed

9 files changed

+69
-124
lines changed

doc/introduction.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ Imbalanced-learn samplers accept the same inputs that in scikit-learn:
3030
matrices;
3131
* ``targets``: array-like (1-D list, pandas.Series, numpy.array).
3232

33+
The output will be of the following type:
34+
35+
* ``data_resampled``: array-like (2-D list, pandas.Dataframe, numpy.array) or
36+
sparse matrices;
37+
* ``targets_resampled``: 1-D numpy.array.
38+
3339
.. topic:: Sparse input
3440

3541
For sparse input the data is **converted to the Compressed Sparse Rows

doc/whats_new/v0.6.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ Enhancement
6767
- :class:`imblearn.under_sampling.RandomUnderSampler`,
6868
:class:`imblearn.over_sampling.RandomOverSampler` can resample when non
6969
finite values are present in ``X``.
70-
:pr:`643` by `Guillaume Lemaitre <glemaitre>`.
70+
:pr:`643` by :user:`Guillaume Lemaitre <glemaitre>`.
71+
72+
- All samplers will output a Pandas DataFrame if a Pandas DataFrame was given
73+
as an input.
74+
:pr:`644` by :user:`Guillaume Lemaitre <glemaitre>`.
7175

7276
- The samples generation in
7377
:class:`imblearn.over_sampling.SMOTE`,

imblearn/base.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def fit(self, X, y):
3232
3333
Parameters
3434
----------
35-
X : {array-like, sparse matrix} of shape (n_samples, n_features)
35+
X : {array-like, dataframe, sparse matrix} of shape \
36+
(n_samples, n_features)
3637
Data array.
3738
3839
y : array-like of shape (n_samples,)
@@ -54,15 +55,16 @@ def fit_resample(self, X, y):
5455
5556
Parameters
5657
----------
57-
X : {array-like, sparse matrix} of shape (n_samples, n_features)
58+
X : {array-like, dataframe, sparse matrix} of shape \
59+
(n_samples, n_features)
5860
Matrix containing the data which have to be sampled.
5961
6062
y : array-like of shape (n_samples,)
6163
Corresponding label for each sample in X.
6264
6365
Returns
6466
-------
65-
X_resampled : {array-like, sparse matrix} of shape \
67+
X_resampled : {array-like, dataframe, sparse matrix} of shape \
6668
(n_samples_new, n_features)
6769
The array containing the resampled data.
6870
@@ -78,12 +80,20 @@ def fit_resample(self, X, y):
7880

7981
output = self._fit_resample(X, y)
8082

83+
if self._columns is not None:
84+
import pandas as pd
85+
X_ = pd.DataFrame(output[0], columns=self._columns)
86+
else:
87+
X_ = output[0]
88+
8189
if binarize_y:
8290
y_sampled = label_binarize(output[1], np.unique(y))
8391
if len(output) == 2:
84-
return output[0], y_sampled
85-
return output[0], y_sampled, output[2]
86-
return output
92+
return X_, y_sampled
93+
return X_, y_sampled, output[2]
94+
if len(output) == 2:
95+
return X_, output[1]
96+
return X_, output[1], output[2]
8797

8898
# define an alias for back-compatibility
8999
fit_sample = fit_resample
@@ -124,8 +134,9 @@ class BaseSampler(SamplerMixin):
124134
def __init__(self, sampling_strategy="auto"):
125135
self.sampling_strategy = sampling_strategy
126136

127-
@staticmethod
128-
def _check_X_y(X, y, accept_sparse=None):
137+
def _check_X_y(self, X, y, accept_sparse=None):
138+
# store the columns name to reconstruct a dataframe
139+
self._columns = X.columns if hasattr(X, "loc") else None
129140
if accept_sparse is None:
130141
accept_sparse = ["csr", "csc"]
131142
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
@@ -238,6 +249,8 @@ def fit_resample(self, X, y):
238249
y_resampled : array-like of shape (n_samples_new,)
239250
The corresponding label of `X_resampled`.
240251
"""
252+
# store the columns name to reconstruct a dataframe
253+
self._columns = X.columns if hasattr(X, "loc") else None
241254
if self.validate:
242255
check_classification_targets(y)
243256
X, y, binarize_y = self._check_X_y(
@@ -250,12 +263,20 @@ def fit_resample(self, X, y):
250263

251264
output = self._fit_resample(X, y)
252265

266+
if self._columns is not None:
267+
import pandas as pd
268+
X_ = pd.DataFrame(output[0], columns=self._columns)
269+
else:
270+
X_ = output[0]
271+
253272
if self.validate and binarize_y:
254273
y_sampled = label_binarize(output[1], np.unique(y))
255274
if len(output) == 2:
256-
return output[0], y_sampled
257-
return output[0], y_sampled, output[2]
258-
return output
275+
return X_, y_sampled
276+
return X_, y_sampled, output[2]
277+
if len(output) == 2:
278+
return X_, output[1]
279+
return X_, output[1], output[2]
259280

260281
def _fit_resample(self, X, y):
261282
func = _identity if self.func is None else self.func

imblearn/ensemble/base.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,12 @@ def __init__(self, sampling_strategy="auto", random_state=None):
7474
super().__init__(sampling_strategy=sampling_strategy)
7575
self.random_state = random_state
7676

77-
@staticmethod
78-
def _check_X_y(X, y):
77+
def _check_X_y(self, X, y):
78+
# store the columns name to reconstruct a dataframe
79+
self._columns = X.columns if hasattr(X, "loc") else None
7980
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
80-
if not hasattr(X, "loc"):
81-
# Do not convert dataframe
82-
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
83-
force_all_finite=False)
81+
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
82+
force_all_finite=False)
8483
y = check_array(
8584
y, accept_sparse=["csr", "csc"], dtype=None, ensure_2d=False
8685
)

imblearn/over_sampling/_smote.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -888,11 +888,12 @@ def __init__(
888888
)
889889
self.categorical_features = categorical_features
890890

891-
@staticmethod
892-
def _check_X_y(X, y):
891+
def _check_X_y(self, X, y):
893892
"""Overwrite the checking to let pass some string for categorical
894893
features.
895894
"""
895+
# store the columns name to reconstruct a dataframe
896+
self._columns = X.columns if hasattr(X, "loc") else None
896897
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
897898
X, y = check_X_y(X, y, accept_sparse=["csr", "csc"], dtype=None)
898899
return X, y, binarize_y

imblearn/over_sampling/tests/test_smote_nc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from sklearn.datasets import make_classification
1515
from sklearn.utils._testing import assert_allclose
16+
from sklearn.utils._testing import assert_array_equal
1617

1718
from imblearn.over_sampling import SMOTENC
1819

@@ -184,7 +185,7 @@ def test_smotenc_pandas():
184185
smote = SMOTENC(categorical_features=categorical_features, random_state=0)
185186
X_res_pd, y_res_pd = smote.fit_resample(X_pd, y)
186187
X_res, y_res = smote.fit_resample(X, y)
187-
assert X_res_pd.tolist() == X_res.tolist()
188+
assert_array_equal(X_res_pd.to_numpy(), X_res)
188189
assert_allclose(y_res_pd, y_res)
189190

190191

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,12 @@ def __init__(
8080
self.random_state = random_state
8181
self.replacement = replacement
8282

83-
@staticmethod
84-
def _check_X_y(X, y):
83+
def _check_X_y(self, X, y):
84+
# store the columns name to reconstruct a dataframe
85+
self._columns = X.columns if hasattr(X, "loc") else None
8586
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
86-
if not hasattr(X, "loc"):
87-
# Do not convert dataframe
88-
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
89-
force_all_finite=False)
87+
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
88+
force_all_finite=False)
9089
y = check_array(
9190
y, accept_sparse=["csr", "csc"], dtype=None, ensure_2d=False
9291
)

imblearn/utils/estimator_checks.py

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
from imblearn.over_sampling.base import BaseOverSampler
3232
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
33-
from imblearn.ensemble.base import BaseEnsembleSampler
3433
from imblearn.under_sampling import NearMiss, ClusterCentroids
3534

3635

@@ -168,12 +167,6 @@ def check_samplers_fit_resample(name, Sampler):
168167
for class_sample in target_stats.keys()
169168
if class_sample != class_minority
170169
)
171-
elif isinstance(sampler, BaseEnsembleSampler):
172-
y_ensemble = y_res[0]
173-
n_samples = min(target_stats.values())
174-
assert all(
175-
value == n_samples for value in Counter(y_ensemble).values()
176-
)
177170

178171

179172
def check_samplers_sampling_strategy_fit_resample(name, Sampler):
@@ -202,12 +195,6 @@ def check_samplers_sampling_strategy_fit_resample(name, Sampler):
202195
sampler.set_params(sampling_strategy=sampling_strategy)
203196
X_res, y_res = sampler.fit_resample(X, y)
204197
assert Counter(y_res)[1] == expected_stat
205-
if isinstance(sampler, BaseEnsembleSampler):
206-
sampling_strategy = {2: 201, 0: 201}
207-
sampler.set_params(sampling_strategy=sampling_strategy)
208-
X_res, y_res = sampler.fit_resample(X, y)
209-
y_ensemble = y_res[0]
210-
assert Counter(y_ensemble)[1] == expected_stat
211198

212199

213200
def check_samplers_sparse(name, Sampler):
@@ -239,17 +226,9 @@ def check_samplers_sparse(name, Sampler):
239226
set_random_state(sampler)
240227
X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
241228
X_res, y_res = sampler.fit_resample(X, y)
242-
if not isinstance(sampler, BaseEnsembleSampler):
243-
assert sparse.issparse(X_res_sparse)
244-
assert_allclose(X_res_sparse.A, X_res)
245-
assert_allclose(y_res_sparse, y_res)
246-
else:
247-
for x_sp, x, y_sp, y in zip(
248-
X_res_sparse, X_res, y_res_sparse, y_res
249-
):
250-
assert sparse.issparse(x_sp)
251-
assert_allclose(x_sp.A, x)
252-
assert_allclose(y_sp, y)
229+
assert sparse.issparse(X_res_sparse)
230+
assert_allclose(X_res_sparse.A, X_res)
231+
assert_allclose(y_res_sparse, y_res)
253232

254233

255234
def check_samplers_pandas(name, Sampler):
@@ -262,7 +241,7 @@ def check_samplers_pandas(name, Sampler):
262241
weights=[0.2, 0.3, 0.5],
263242
random_state=0,
264243
)
265-
X_pd = pd.DataFrame(X)
244+
X_pd = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
266245
sampler = Sampler()
267246
if isinstance(Sampler(), NearMiss):
268247
samplers = [Sampler(version=version) for version in (1, 2, 3)]
@@ -274,7 +253,11 @@ def check_samplers_pandas(name, Sampler):
274253
set_random_state(sampler)
275254
X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y)
276255
X_res, y_res = sampler.fit_resample(X, y)
277-
assert_allclose(X_res_pd, X_res)
256+
257+
# check that we return a pandas dataframe if a dataframe was given in
258+
assert isinstance(X_res_pd, pd.DataFrame)
259+
assert X_pd.columns.to_list() == X_res_pd.columns.to_list()
260+
assert_allclose(X_res_pd.to_numpy(), X_res)
278261
assert_allclose(y_res_pd, y_res)
279262

280263

@@ -293,13 +276,8 @@ def check_samplers_multiclass_ova(name, Sampler):
293276
X_res, y_res = sampler.fit_resample(X, y)
294277
X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)
295278
assert_allclose(X_res, X_res_ova)
296-
if issubclass(Sampler, BaseEnsembleSampler):
297-
for batch_y, batch_y_ova in zip(y_res, y_res_ova):
298-
assert type_of_target(batch_y_ova) == type_of_target(y_ova)
299-
assert_allclose(batch_y, batch_y_ova.argmax(axis=1))
300-
else:
301-
assert type_of_target(y_res_ova) == type_of_target(y_ova)
302-
assert_allclose(y_res, y_res_ova.argmax(axis=1))
279+
assert type_of_target(y_res_ova) == type_of_target(y_ova)
280+
assert_allclose(y_res, y_res_ova.argmax(axis=1))
303281

304282

305283
def check_samplers_preserve_dtype(name, Sampler):

0 commit comments

Comments
 (0)