Skip to content

Commit 77d3dec

Browse files
committed
ENH Allows pandas series in/out
1 parent 158258e commit 77d3dec

File tree

6 files changed

+92
-27
lines changed

6 files changed

+92
-27
lines changed

doc/introduction.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ The output will be of the following type:
3434

3535
* ``data_resampled``: array-like (2-D list, pandas.DataFrame, numpy.array) or
3636
sparse matrices;
37-
* ``targets_resampled``: 1-D numpy.array.
37+
* ``targets_resampled``: 1-D numpy.array or pandas.Series.
3838

3939
.. topic:: Sparse input
4040

doc/whats_new/v0.6.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ Enhancement
5757
- :class:`imblearn.under_sampling.RandomUnderSampling`,
5858
:class:`imblearn.over_sampling.RandomOverSampling`,
5959
:class:`imblearn.datasets.make_imbalance` accepts Pandas DataFrame in and
60-
will output Pandas DataFrame.
60+
will output Pandas DataFrame. Similarly, it will accept Pandas Series in and
61+
will output Pandas Series.
6162
:pr:`636` by :user:`Guillaume Lemaitre <glemaitre>`.
6263

6364
- :class:`imblearn.FunctionSampler` accepts a parameter ``validate`` allowing

imblearn/base.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -80,20 +80,28 @@ def fit_resample(self, X, y):
8080

8181
output = self._fit_resample(X, y)
8282

83-
if self._columns is not None:
83+
if self._X_columns is not None or self._y_name is not None:
8484
import pandas as pd
85-
X_ = pd.DataFrame(output[0], columns=self._columns)
85+
86+
if self._X_columns is not None:
87+
X_ = pd.DataFrame(output[0], columns=self._X_columns)
88+
X_ = X_.astype(self._X_dtypes)
8689
else:
8790
X_ = output[0]
8891

92+
y_ = (label_binarize(output[1], np.unique(y))
93+
if binarize_y else output[1])
94+
95+
if self._y_name is not None:
96+
y_ = pd.Series(y_, dtype=self._y_dtype, name=self._y_name)
97+
8998
if binarize_y:
90-
y_sampled = label_binarize(output[1], np.unique(y))
9199
if len(output) == 2:
92-
return X_, y_sampled
93-
return X_, y_sampled, output[2]
100+
return X_, y_
101+
return X_, y_, output[2]
94102
if len(output) == 2:
95-
return X_, output[1]
96-
return X_, output[1], output[2]
103+
return X_, y_
104+
return X_, y_, output[2]
97105

98106
# define an alias for back-compatibility
99107
fit_sample = fit_resample
@@ -135,8 +143,22 @@ def __init__(self, sampling_strategy="auto"):
135143
self.sampling_strategy = sampling_strategy
136144

137145
def _check_X_y(self, X, y, accept_sparse=None):
138-
# store the columns name to reconstruct a dataframe
139-
self._columns = X.columns if hasattr(X, "loc") else None
146+
if hasattr(X, "loc"):
147+
# store information to build dataframe
148+
self._X_columns = X.columns
149+
self._X_dtypes = X.dtypes
150+
else:
151+
self._X_columns = None
152+
self._X_dtypes = None
153+
154+
if hasattr(y, "loc"):
155+
# store information to build a series
156+
self._y_name = y.name
157+
self._y_dtype = y.dtype
158+
else:
159+
self._y_name = None
160+
self._y_dtype = None
161+
140162
if accept_sparse is None:
141163
accept_sparse = ["csr", "csc"]
142164
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
@@ -263,20 +285,31 @@ def fit_resample(self, X, y):
263285

264286
output = self._fit_resample(X, y)
265287

266-
if self._columns is not None:
267-
import pandas as pd
268-
X_ = pd.DataFrame(output[0], columns=self._columns)
288+
if self.validate:
289+
if self._X_columns is not None or self._y_name is not None:
290+
import pandas as pd
291+
292+
if self._X_columns is not None:
293+
X_ = pd.DataFrame(output[0], columns=self._X_columns)
294+
X_ = X_.astype(self._X_dtypes)
295+
else:
296+
X_ = output[0]
297+
298+
y_ = (label_binarize(output[1], np.unique(y))
299+
if binarize_y else output[1])
300+
301+
if self._y_name is not None:
302+
y_ = pd.Series(y_, dtype=self._y_dtype, name=self._y_name)
269303
else:
270-
X_ = output[0]
304+
X_, y_ = output[0], output[1]
271305

272-
if self.validate and binarize_y:
273-
y_sampled = label_binarize(output[1], np.unique(y))
306+
if binarize_y:
274307
if len(output) == 2:
275-
return X_, y_sampled
276-
return X_, y_sampled, output[2]
308+
return X_, y_
309+
return X_, y_, output[2]
277310
if len(output) == 2:
278-
return X_, output[1]
279-
return X_, output[1], output[2]
311+
return X_, y_
312+
return X_, y_, output[2]
280313

281314
def _fit_resample(self, X, y):
282315
func = _identity if self.func is None else self.func

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,22 @@ def __init__(self, sampling_strategy="auto", random_state=None):
7575
self.random_state = random_state
7676

7777
def _check_X_y(self, X, y):
78-
# store the columns name to reconstruct a dataframe
79-
self._columns = X.columns if hasattr(X, "loc") else None
78+
if hasattr(X, "loc"):
79+
# store information to build dataframe
80+
self._X_columns = X.columns
81+
self._X_dtypes = X.dtypes
82+
else:
83+
self._X_columns = None
84+
self._X_dtypes = None
85+
86+
if hasattr(y, "loc"):
87+
# store information to build a series
88+
self._y_name = y.name
89+
self._y_dtype = y.dtype
90+
else:
91+
self._y_name = None
92+
self._y_dtype = None
93+
8094
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
8195
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
8296
force_all_finite=False)

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,22 @@ def __init__(
8181
self.replacement = replacement
8282

8383
def _check_X_y(self, X, y):
84-
# store the columns name to reconstruct a dataframe
85-
self._columns = X.columns if hasattr(X, "loc") else None
84+
if hasattr(X, "loc"):
85+
# store information to build dataframe
86+
self._X_columns = X.columns
87+
self._X_dtypes = X.dtypes
88+
else:
89+
self._X_columns = None
90+
self._X_dtypes = None
91+
92+
if hasattr(y, "loc"):
93+
# store information to build a series
94+
self._y_name = y.name
95+
self._y_dtype = y.dtype
96+
else:
97+
self._y_name = None
98+
self._y_dtype = None
99+
86100
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
87101
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
88102
force_all_finite=False)

imblearn/utils/estimator_checks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def check_samplers_pandas(name, Sampler):
242242
random_state=0,
243243
)
244244
X_pd = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
245+
y_pd = pd.Series(y, name="class")
245246
sampler = Sampler()
246247
if isinstance(Sampler(), NearMiss):
247248
samplers = [Sampler(version=version) for version in (1, 2, 3)]
@@ -251,14 +252,16 @@ def check_samplers_pandas(name, Sampler):
251252

252253
for sampler in samplers:
253254
set_random_state(sampler)
254-
X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y)
255+
X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y_pd)
255256
X_res, y_res = sampler.fit_resample(X, y)
256257

257258
# check that we return a pandas dataframe if a dataframe was given in
258259
assert isinstance(X_res_pd, pd.DataFrame)
260+
assert isinstance(y_res_pd, pd.Series)
259261
assert X_pd.columns.to_list() == X_res_pd.columns.to_list()
262+
assert y_pd.name == y_res_pd.name
260263
assert_allclose(X_res_pd.to_numpy(), X_res)
261-
assert_allclose(y_res_pd, y_res)
264+
assert_allclose(y_res_pd.to_numpy(), y_res)
262265

263266

264267
def check_samplers_multiclass_ova(name, Sampler):

0 commit comments

Comments
 (0)