Skip to content

Commit e936bdd

Browse files
committed
ENH dataframe in/out for all samplers
1 parent 45b538c commit e936bdd

File tree

5 files changed

+48
-23
lines changed

5 files changed

+48
-23
lines changed

doc/whats_new/v0.6.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,11 @@ Enhancement
5454
- :class:`imblearn.under_sampling.RandomUnderSampler`,
5555
:class:`imblearn.over_sampling.RandomOverSampler` can resample when non
5656
finite values are present in ``X``.
57-
:pr:`643` by `Guillaume Lemaitre <glemaitre>`.
57+
:pr:`643` by :user:`Guillaume Lemaitre <glemaitre>`.
58+
59+
- All samplers will output a Pandas DataFrame if a Pandas DataFrame was given
60+
as an input.
61+
:pr:`644` by :user:`Guillaume Lemaitre <glemaitre>`.
5862

5963
Deprecation
6064
...........

imblearn/base.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,20 @@ def fit_resample(self, X, y):
7878

7979
output = self._fit_resample(X, y)
8080

81+
if self._columns is not None:
82+
import pandas as pd
83+
X_ = pd.DataFrame(output[0], columns=self._columns)
84+
else:
85+
X_ = output[0]
86+
8187
if binarize_y:
8288
y_sampled = label_binarize(output[1], np.unique(y))
8389
if len(output) == 2:
84-
return output[0], y_sampled
85-
return output[0], y_sampled, output[2]
86-
return output
90+
return X_, y_sampled
91+
return X_, y_sampled, output[2]
92+
if len(output) == 2:
93+
return X_, output[1]
94+
return X_, output[1], output[2]
8795

8896
# define an alias for back-compatibility
8997
fit_sample = fit_resample
@@ -124,8 +132,9 @@ class BaseSampler(SamplerMixin):
124132
def __init__(self, sampling_strategy="auto"):
125133
self.sampling_strategy = sampling_strategy
126134

127-
@staticmethod
128-
def _check_X_y(X, y, accept_sparse=None):
135+
def _check_X_y(self, X, y, accept_sparse=None):
136+
# store the columns name to reconstruct a dataframe
137+
self._columns = X.columns if hasattr(X, "loc") else None
129138
if accept_sparse is None:
130139
accept_sparse = ["csr", "csc"]
131140
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
@@ -238,6 +247,8 @@ def fit_resample(self, X, y):
238247
y_resampled : array-like of shape (n_samples_new,)
239248
The corresponding label of `X_resampled`.
240249
"""
250+
# store the columns name to reconstruct a dataframe
251+
self._columns = X.columns if hasattr(X, "loc") else None
241252
if self.validate:
242253
check_classification_targets(y)
243254
X, y, binarize_y = self._check_X_y(
@@ -250,12 +261,20 @@ def fit_resample(self, X, y):
250261

251262
output = self._fit_resample(X, y)
252263

264+
if self._columns is not None:
265+
import pandas as pd
266+
X_ = pd.DataFrame(output[0], columns=self._columns)
267+
else:
268+
X_ = output[0]
269+
253270
if self.validate and binarize_y:
254271
y_sampled = label_binarize(output[1], np.unique(y))
255272
if len(output) == 2:
256-
return output[0], y_sampled
257-
return output[0], y_sampled, output[2]
258-
return output
273+
return X_, y_sampled
274+
return X_, y_sampled, output[2]
275+
if len(output) == 2:
276+
return X_, output[1]
277+
return X_, output[1], output[2]
259278

260279
def _fit_resample(self, X, y):
261280
func = _identity if self.func is None else self.func

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,12 @@ def __init__(self, sampling_strategy="auto", random_state=None):
7474
super().__init__(sampling_strategy=sampling_strategy)
7575
self.random_state = random_state
7676

77-
@staticmethod
78-
def _check_X_y(X, y):
77+
def _check_X_y(self, X, y):
78+
# store the columns name to reconstruct a dataframe
79+
self._columns = X.columns if hasattr(X, "loc") else None
7980
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
80-
if not hasattr(X, "loc"):
81-
# Do not convert dataframe
82-
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
83-
force_all_finite=False)
81+
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
82+
force_all_finite=False)
8483
y = check_array(
8584
y, accept_sparse=["csr", "csc"], dtype=None, ensure_2d=False
8685
)

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,12 @@ def __init__(
8080
self.random_state = random_state
8181
self.replacement = replacement
8282

83-
@staticmethod
84-
def _check_X_y(X, y):
83+
def _check_X_y(self, X, y):
84+
# store the columns name to reconstruct a dataframe
85+
self._columns = X.columns if hasattr(X, "loc") else None
8586
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
86-
if not hasattr(X, "loc"):
87-
# Do not convert dataframe
88-
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
89-
force_all_finite=False)
87+
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
88+
force_all_finite=False)
9089
y = check_array(
9190
y, accept_sparse=["csr", "csc"], dtype=None, ensure_2d=False
9291
)

imblearn/utils/estimator_checks.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def check_samplers_pandas(name, Sampler):
262262
weights=[0.2, 0.3, 0.5],
263263
random_state=0,
264264
)
265-
X_pd = pd.DataFrame(X)
265+
X_pd = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
266266
sampler = Sampler()
267267
if isinstance(Sampler(), NearMiss):
268268
samplers = [Sampler(version=version) for version in (1, 2, 3)]
@@ -274,7 +274,11 @@ def check_samplers_pandas(name, Sampler):
274274
set_random_state(sampler)
275275
X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y)
276276
X_res, y_res = sampler.fit_resample(X, y)
277-
assert_allclose(X_res_pd, X_res)
277+
278+
# check that we return a pandas dataframe if a dataframe was given in
279+
assert isinstance(X_res_pd, pd.DataFrame)
280+
assert X_pd.columns.to_list() == X_res_pd.columns.to_list()
281+
assert_allclose(X_res_pd.to_numpy(), X_res)
278282
assert_allclose(y_res_pd, y_res)
279283

280284

0 commit comments

Comments
 (0)