Skip to content

Commit b1d2d56

Browse files
committed
solve smotenc
1 parent e936bdd commit b1d2d56

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

imblearn/over_sampling/_smote.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -900,11 +900,12 @@ def __init__(
900900
)
901901
self.categorical_features = categorical_features
902902

903-
@staticmethod
904-
def _check_X_y(X, y):
903+
def _check_X_y(self, X, y):
905904
"""Overwrite the checking to let pass some string for categorical
906905
features.
907906
"""
907+
# store the columns name to reconstruct a dataframe
908+
self._columns = X.columns if hasattr(X, "loc") else None
908909
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
909910
X, y = check_X_y(X, y, accept_sparse=["csr", "csc"], dtype=None)
910911
return X, y, binarize_y

imblearn/over_sampling/tests/test_smote_nc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from sklearn.datasets import make_classification
1515
from sklearn.utils._testing import assert_allclose
16+
from sklearn.utils._testing import assert_array_equal
1617

1718
from imblearn.over_sampling import SMOTENC
1819

@@ -184,7 +185,7 @@ def test_smotenc_pandas():
184185
smote = SMOTENC(categorical_features=categorical_features, random_state=0)
185186
X_res_pd, y_res_pd = smote.fit_resample(X_pd, y)
186187
X_res, y_res = smote.fit_resample(X, y)
187-
assert X_res_pd.tolist() == X_res.tolist()
188+
assert_array_equal(X_res_pd.to_numpy(), X_res)
188189
assert_allclose(y_res_pd, y_res)
189190

190191

0 commit comments

Comments
 (0)