Skip to content

Commit f356284

Browse files
authored
ENH allows pandas series in/out for the target (#647)
1 parent 158258e commit f356284

File tree

7 files changed

+101
-35
lines changed

7 files changed

+101
-35
lines changed

doc/introduction.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ The output will be of the following type:
3434

3535
* ``data_resampled``: array-like (2-D list, pandas.Dataframe, numpy.array) or
3636
sparse matrices;
37-
* ``targets_resampled``: 1-D numpy.array.
37+
* ``targets_resampled``: 1-D numpy.array or pd.Series.
3838

3939
.. topic:: Sparse input
4040

doc/whats_new/v0.6.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ Enhancement
5757
- :class:`imblearn.under_sampling.RandomUnderSampling`,
5858
:class:`imblearn.over_sampling.RandomOverSampling`,
5959
:class:`imblearn.datasets.make_imbalance` accepts Pandas DataFrame in and
60-
will output Pandas DataFrame.
60+
will output Pandas DataFrame. Similarly, it will accepts Pandas Series in and
61+
will output Pandas Series.
6162
:pr:`636` by :user:`Guillaume Lemaitre <glemaitre>`.
6263

6364
- :class:`imblearn.FunctionSampler` accepts a parameter ``validate`` allowing

imblearn/base.py

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -80,20 +80,22 @@ def fit_resample(self, X, y):
8080

8181
output = self._fit_resample(X, y)
8282

83-
if self._columns is not None:
83+
if self._X_columns is not None or self._y_name is not None:
8484
import pandas as pd
85-
X_ = pd.DataFrame(output[0], columns=self._columns)
85+
86+
if self._X_columns is not None:
87+
X_ = pd.DataFrame(output[0], columns=self._X_columns)
88+
X_ = X_.astype(self._X_dtypes)
8689
else:
8790
X_ = output[0]
8891

89-
if binarize_y:
90-
y_sampled = label_binarize(output[1], np.unique(y))
91-
if len(output) == 2:
92-
return X_, y_sampled
93-
return X_, y_sampled, output[2]
94-
if len(output) == 2:
95-
return X_, output[1]
96-
return X_, output[1], output[2]
92+
y_ = (label_binarize(output[1], np.unique(y))
93+
if binarize_y else output[1])
94+
95+
if self._y_name is not None:
96+
y_ = pd.Series(y_, dtype=self._y_dtype, name=self._y_name)
97+
98+
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
9799

98100
# define an alias for back-compatibility
99101
fit_sample = fit_resample
@@ -135,8 +137,22 @@ def __init__(self, sampling_strategy="auto"):
135137
self.sampling_strategy = sampling_strategy
136138

137139
def _check_X_y(self, X, y, accept_sparse=None):
138-
# store the columns name to reconstruct a dataframe
139-
self._columns = X.columns if hasattr(X, "loc") else None
140+
if hasattr(X, "loc"):
141+
# store information to build dataframe
142+
self._X_columns = X.columns
143+
self._X_dtypes = X.dtypes
144+
else:
145+
self._X_columns = None
146+
self._X_dtypes = None
147+
148+
if hasattr(y, "loc"):
149+
# store information to build a series
150+
self._y_name = y.name
151+
self._y_dtype = y.dtype
152+
else:
153+
self._y_name = None
154+
self._y_dtype = None
155+
140156
if accept_sparse is None:
141157
accept_sparse = ["csr", "csc"]
142158
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
@@ -263,20 +279,24 @@ def fit_resample(self, X, y):
263279

264280
output = self._fit_resample(X, y)
265281

266-
if self._columns is not None:
267-
import pandas as pd
268-
X_ = pd.DataFrame(output[0], columns=self._columns)
269-
else:
270-
X_ = output[0]
282+
if self.validate:
283+
if self._X_columns is not None or self._y_name is not None:
284+
import pandas as pd
271285

272-
if self.validate and binarize_y:
273-
y_sampled = label_binarize(output[1], np.unique(y))
274-
if len(output) == 2:
275-
return X_, y_sampled
276-
return X_, y_sampled, output[2]
277-
if len(output) == 2:
278-
return X_, output[1]
279-
return X_, output[1], output[2]
286+
if self._X_columns is not None:
287+
X_ = pd.DataFrame(output[0], columns=self._X_columns)
288+
X_ = X_.astype(self._X_dtypes)
289+
else:
290+
X_ = output[0]
291+
292+
y_ = (label_binarize(output[1], np.unique(y))
293+
if binarize_y else output[1])
294+
295+
if self._y_name is not None:
296+
y_ = pd.Series(y_, dtype=self._y_dtype, name=self._y_name)
297+
298+
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
299+
return output
280300

281301
def _fit_resample(self, X, y):
282302
func = _identity if self.func is None else self.func

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,22 @@ def __init__(self, sampling_strategy="auto", random_state=None):
7575
self.random_state = random_state
7676

7777
def _check_X_y(self, X, y):
78-
# store the columns name to reconstruct a dataframe
79-
self._columns = X.columns if hasattr(X, "loc") else None
78+
if hasattr(X, "loc"):
79+
# store information to build dataframe
80+
self._X_columns = X.columns
81+
self._X_dtypes = X.dtypes
82+
else:
83+
self._X_columns = None
84+
self._X_dtypes = None
85+
86+
if hasattr(y, "loc"):
87+
# store information to build a series
88+
self._y_name = y.name
89+
self._y_dtype = y.dtype
90+
else:
91+
self._y_name = None
92+
self._y_dtype = None
93+
8094
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
8195
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
8296
force_all_finite=False)

imblearn/over_sampling/_smote.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -892,8 +892,22 @@ def _check_X_y(self, X, y):
892892
"""Overwrite the checking to let pass some string for categorical
893893
features.
894894
"""
895-
# store the columns name to reconstruct a dataframe
896-
self._columns = X.columns if hasattr(X, "loc") else None
895+
if hasattr(X, "loc"):
896+
# store information to build dataframe
897+
self._X_columns = X.columns
898+
self._X_dtypes = X.dtypes
899+
else:
900+
self._X_columns = None
901+
self._X_dtypes = None
902+
903+
if hasattr(y, "loc"):
904+
# store information to build a series
905+
self._y_name = y.name
906+
self._y_dtype = y.dtype
907+
else:
908+
self._y_name = None
909+
self._y_dtype = None
910+
897911
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
898912
X, y = check_X_y(X, y, accept_sparse=["csr", "csc"], dtype=None)
899913
return X, y, binarize_y

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,22 @@ def __init__(
8181
self.replacement = replacement
8282

8383
def _check_X_y(self, X, y):
84-
# store the columns name to reconstruct a dataframe
85-
self._columns = X.columns if hasattr(X, "loc") else None
84+
if hasattr(X, "loc"):
85+
# store information to build dataframe
86+
self._X_columns = X.columns
87+
self._X_dtypes = X.dtypes
88+
else:
89+
self._X_columns = None
90+
self._X_dtypes = None
91+
92+
if hasattr(y, "loc"):
93+
# store information to build a series
94+
self._y_name = y.name
95+
self._y_dtype = y.dtype
96+
else:
97+
self._y_name = None
98+
self._y_dtype = None
99+
86100
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
87101
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
88102
force_all_finite=False)

imblearn/utils/estimator_checks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def check_samplers_pandas(name, Sampler):
242242
random_state=0,
243243
)
244244
X_pd = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
245+
y_pd = pd.Series(y, name="class")
245246
sampler = Sampler()
246247
if isinstance(Sampler(), NearMiss):
247248
samplers = [Sampler(version=version) for version in (1, 2, 3)]
@@ -251,14 +252,16 @@ def check_samplers_pandas(name, Sampler):
251252

252253
for sampler in samplers:
253254
set_random_state(sampler)
254-
X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y)
255+
X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y_pd)
255256
X_res, y_res = sampler.fit_resample(X, y)
256257

257258
# check that we return a pandas dataframe if a dataframe was given in
258259
assert isinstance(X_res_pd, pd.DataFrame)
260+
assert isinstance(y_res_pd, pd.Series)
259261
assert X_pd.columns.to_list() == X_res_pd.columns.to_list()
262+
assert y_pd.name == y_res_pd.name
260263
assert_allclose(X_res_pd.to_numpy(), X_res)
261-
assert_allclose(y_res_pd, y_res)
264+
assert_allclose(y_res_pd.to_numpy(), y_res)
262265

263266

264267
def check_samplers_multiclass_ova(name, Sampler):

0 commit comments

Comments
 (0)