Skip to content

Commit 32eda46

Browse files
committed
iter
1 parent a6e975b commit 32eda46

File tree

2 files changed: 14 additions (+14) and 2 deletions (-2)

2 files changed

+14
-2
lines changed

imblearn/base.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,9 @@ def fit_resample(self, X, y):
107107
output = self._fit_resample(X, y)
108108

109109
if binarize_y:
110-
y_ = label_binarize(output[1], classes=np.unique(y))
110+
y_ = label_binarize(
111+
output[1], classes=list(self._classes_counts.keys())
112+
)
111113
else:
112114
y_ = output[1]
113115

@@ -291,7 +293,9 @@ def fit_resample(self, X, y):
291293

292294
if self.validate:
293295
if binarize_y:
294-
y_ = label_binarize(output[1], classes=np.unique(y))
296+
y_ = label_binarize(
297+
output[1], classes=list(self._classes_counts.keys())
298+
)
295299
else:
296300
y_ = output[1]
297301
X_, y_ = arrays_transformer.transform(output[0], y_)

imblearn/utils/estimator_checks.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -327,6 +327,7 @@ def check_samplers_dask_array(name, sampler_orig):
327327

328328
def check_samplers_dask_dataframe(name, sampler_orig):
329329
pytest.importorskip("dask")
330+
pd = pytest.importorskip("pandas")
330331
from dask import dataframe
331332
sampler = clone(sampler_orig)
332333
# Check that the samplers handle dask dataframe and dask series
@@ -342,20 +343,27 @@ def check_samplers_dask_dataframe(name, sampler_orig):
342343
)
343344
y_s = dataframe.from_array(y)
344345
y_s = y_s.rename("target")
346+
y_s_ohe = dataframe.get_dummies(
347+
y_s.astype(pd.CategoricalDtype(categories=[0, 1, 2]))
348+
)
345349

346350
for validate_if_dask_collection in (True, False):
347351
sampler.set_params(
348352
validate_if_dask_collection=validate_if_dask_collection
349353
)
350354
X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
355+
# FIXME: not supported with validate=False
356+
X_res, y_res_s_ohe = sampler.fit_resample(X, y_s_ohe)
351357
X_res, y_res = sampler.fit_resample(X, y)
352358

353359
# check that we return the same type for dataframes or series types
354360
assert isinstance(X_res_df, dataframe.DataFrame)
355361
assert isinstance(y_res_s, dataframe.Series)
362+
assert isinstance(y_res_s_ohe, dataframe.DataFrame)
356363

357364
assert X_df.columns.to_list() == X_res_df.columns.to_list()
358365
assert y_s.name == y_res_s.name
366+
assert y_s_ohe.columns.to_list() == y_res_s_ohe.columns.to_list()
359367

360368
assert_allclose(np.array(X_res_df), X_res)
361369
assert_allclose(np.array(y_res_s), y_res)

0 commit comments

Comments
 (0)