Skip to content

Commit 20ba934

Browse files
committed
iter
1 parent f2d0ec0 commit 20ba934

File tree

3 files changed

+35
-9
lines changed

3 files changed

+35
-9
lines changed

imblearn/dask/tests/test_utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,6 @@
88
from imblearn.dask.utils import type_of_target
99

1010

11-
def test_type_of_target_error():
    """Check that ``type_of_target`` raises ``ValueError`` on a NumPy array."""
    numpy_target = np.arange(10)
    expected_message = "Expected a Dask array, series or dataframe."

    # A plain NumPy array is not one of the Dask containers named in the
    # expected error message, so the call must fail.
    with pytest.raises(ValueError, match=expected_message):
        type_of_target(numpy_target)
17-
18-
1911
@pytest.mark.parametrize(
2012
"y, expected_result",
2113
[

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def _more_tags(self):
140140
"2darray",
141141
"string",
142142
"dask-array",
143+
"dask-dataframe"
143144
],
144145
"sample_indices": True,
145146
"allow_nan": True,

imblearn/utils/estimator_checks.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def _yield_sampler_checks(sampler):
6161
yield check_samplers_pandas
6262
if "dask-array" in tags["X_types"]:
6363
yield check_samplers_dask_array
64+
if "dask-dataframe" in tags["X_types"]:
65+
yield check_samplers_dask_dataframe
6466
yield check_samplers_list
6567
yield check_samplers_multiclass_ova
6668
yield check_samplers_preserve_dtype
@@ -295,7 +297,7 @@ def check_samplers_pandas(name, sampler):
295297

296298
def check_samplers_dask_array(name, sampler):
297299
dask = pytest.importorskip("dask")
298-
# Check that the samplers handle pandas dataframe and pandas series
300+
# Check that the samplers handle dask array
299301
X, y = make_classification(
300302
n_samples=1000,
301303
n_classes=3,
@@ -317,6 +319,37 @@ def check_samplers_dask_array(name, sampler):
317319
assert_allclose(y_res_dask, y_res)
318320

319321

322+
def check_samplers_dask_dataframe(name, sampler):
    """Check that a sampler handles a Dask dataframe and a Dask series.

    ``X`` is wrapped in a Dask dataframe and ``y`` in a Dask series before
    being passed to ``fit_resample``; the resampled outputs must come back
    as the same Dask container types.

    Parameters
    ----------
    name : str
        Name of the sampler under test. Unused in this check; presumably
        kept so that every check shares the ``(name, sampler)`` signature.
    sampler : object
        Sampler exposing ``fit_resample(X, y)``.
    """
    # Import the submodule explicitly: importing only the top-level ``dask``
    # package does not guarantee that ``dask.dataframe`` is loaded, and the
    # check should be skipped (not error out) when it is unavailable.
    dd = pytest.importorskip("dask.dataframe")

    X, y = make_classification(
        n_samples=1000,
        n_classes=3,
        n_informative=4,
        weights=[0.2, 0.3, 0.5],
        random_state=0,
    )
    X_df = dd.from_array(X, columns=[str(i) for i in range(X.shape[1])])
    y_s = dd.from_array(y)

    X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
    # In-memory reference result; only needed by the pending comparisons
    # described in the TODO below.
    X_res, y_res = sampler.fit_resample(X, y)

    # Check that we return the same container types as the inputs.
    assert isinstance(X_res_df, dd.DataFrame)
    assert isinstance(y_res_s, dd.Series)

    # TODO: also assert that the column names of ``X_df`` and the name of
    # ``y_s`` are preserved, and that the materialised values of
    # ``X_res_df`` / ``y_res_s`` match ``X_res`` / ``y_res``. These
    # comparisons were left disabled in the original check.
351+
352+
320353
def check_samplers_list(name, sampler):
321354
# Check that the samplers can handle simple lists
322355
X, y = make_classification(

0 commit comments

Comments
 (0)